diff --git a/CMakeLists.txt b/CMakeLists.txt index ddc1410..7104450 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,13 +20,25 @@ endif() # Download dependencies include(FetchContent) +set(SPBLAS_CPU_BACKEND OFF) +set(SPBLAS_GPU_BACKEND OFF) + if (ENABLE_ONEMKL_SYCL) + set(SPBLAS_CPU_BACKEND ON) + set(SPBLAS_GPU_BACKEND ON) find_package(MKL REQUIRED) target_link_libraries(spblas INTERFACE MKL::MKL_SYCL) # SYCL APIs set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPBLAS_ENABLE_ONEMKL_SYCL") + + FetchContent_Declare( + sycl_thrust + GIT_REPOSITORY https://github.com/SparseBLAS/sycl-thrust.git + GIT_TAG main) + FetchContent_MakeAvailable(sycl_thrust) endif() if (ENABLE_ARMPL) + set(SPBLAS_CPU_BACKEND ON) if (NOT DEFINED ENV{ARMPL_DIR}) message(FATAL_ERROR "Environment variable ARMPL_DIR must be set when the ArmPL is enabled.") endif() @@ -36,6 +48,7 @@ if (ENABLE_ARMPL) endif() if (ENABLE_AOCLSPARSE) + set(SPBLAS_CPU_BACKEND ON) if (NOT DEFINED ENV{AOCLSPARSE_DIR}) message(FATAL_ERROR "Environment variable AOCLSPARSE_DIR must be set when the AOCLSPARSE is enabled.") endif() @@ -81,6 +94,15 @@ if (ENABLE_CUSPARSE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPBLAS_ENABLE_CUSPARSE") endif() +# If no vendor backend is enabled, enable CPU backend for reference implementation +if (NOT ENABLE_ONEMKL_SYCL AND + NOT ENABLE_ARMPL AND + NOT ENABLE_AOCLSPARSE AND + NOT ENABLE_ROCSPARSE AND + NOT ENABLE_CUSPARSE) + set(SPBLAS_CPU_BACKEND ON) +endif() + # turn on/off debug logging if (LOG_LEVEL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLOG_LEVEL=${LOG_LEVEL}") # SPBLAS_DEBUG | SPBLAS_WARNING | SPBLAS_TRACE | SPBLAS_INFO diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 604b50c..fcf3a82 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -3,20 +3,23 @@ function(add_example example_name) target_link_libraries(${example_name} spblas fmt) endfunction() -if (NOT SPBLAS_GPU_BACKEND) +# CPU examples +if (SPBLAS_CPU_BACKEND) add_example(simple_spmv) add_example(simple_spmm) add_example(simple_spgemm) add_example(simple_sptrsv) - add_example(matrix_opt_example) add_example(spmm_csc) -else() - add_subdirectory(device) + add_example(matrix_opt_example) endif() -if (ENABLE_ROCSPARSE) - add_subdirectory(rocsparse) -endif() -if (ENABLE_CUSPARSE) - add_subdirectory(cusparse) +# GPU examples +if (SPBLAS_GPU_BACKEND) + add_subdirectory(device) + if (ENABLE_CUSPARSE) + add_subdirectory(cusparse) + endif() + if (ENABLE_ROCSPARSE) + add_subdirectory(rocsparse) + endif() endif() diff --git a/examples/cusparse/cusparse_simple_spmv.cpp b/examples/cusparse/cusparse_simple_spmv.cpp index f8d192a..33a832f 100644 --- a/examples/cusparse/cusparse_simple_spmv.cpp +++ b/examples/cusparse/cusparse_simple_spmv.cpp @@ -76,8 +76,8 @@ int main(int argc, char** argv) { std::span y_span(d_y, m); // y = A * x - spblas::spmv_state_t state; - spblas::multiply(state, a, x_span, y_span); + spblas::operation_info_t info; + spblas::multiply(info, a, x_span, y_span); CUDA_CHECK( cudaMemcpy(y.data(), d_y, y.size() * sizeof(value_t), cudaMemcpyDefault)); diff --git a/examples/device/CMakeLists.txt b/examples/device/CMakeLists.txt index fe9c4ed..cb79234 100644 --- a/examples/device/CMakeLists.txt +++ b/examples/device/CMakeLists.txt @@ -2,12 +2,15 @@ function(add_device_example example_name) add_executable(${example_name} ${example_name}.cpp) if (ENABLE_ROCSPARSE) set_source_files_properties(${example_name}.cpp PROPERTIES LANGUAGE HIP) + target_link_libraries(${example_name} roc::rocthrust) elseif (ENABLE_CUSPARSE) target_link_libraries(${example_name} Thrust) + elseif (ENABLE_ONEMKL_SYCL) + target_link_libraries(${example_name} sycl_thrust) else() message(FATAL_ERROR "Device backend not found.") endif() target_link_libraries(${example_name} spblas fmt) endfunction() -add_device_example(simple_spmv) +add_device_example(device_spmv) diff --git a/examples/device/simple_spmv.cpp b/examples/device/device_spmv.cpp similarity index 96% rename from examples/device/simple_spmv.cpp rename to examples/device/device_spmv.cpp index b5a4162..b9fae95 100644 --- a/examples/device/simple_spmv.cpp +++ b/examples/device/device_spmv.cpp @@ -56,8 +56,7 @@ int main(int argc, char** argv) { std::span y_span(d_y.data().get(), m); // y = A * x - spblas::spmv_state_t state; - spblas::multiply(state, a, x_span, y_span); + spblas::multiply(a, x_span, y_span); thrust::copy(d_y.begin(), d_y.end(), y.begin()); diff --git a/examples/rocsparse/rocsparse_simple_spmv.cpp b/examples/rocsparse/rocsparse_simple_spmv.cpp index a1e363d..31092ff 100644 --- a/examples/rocsparse/rocsparse_simple_spmv.cpp +++ b/examples/rocsparse/rocsparse_simple_spmv.cpp @@ -76,8 +76,8 @@ int main(int argc, char** argv) { std::span y_span(d_y, m); // y = A * x - spblas::spmv_state_t state; - spblas::multiply(state, a, x_span, y_span); + spblas::operation_info_t info; + spblas::multiply(info, a, x_span, y_span); HIP_CHECK( hipMemcpy(y.data(), d_y, y.size() * sizeof(value_t), hipMemcpyDefault)); diff --git a/include/spblas/detail/operation_info_t.hpp b/include/spblas/detail/operation_info_t.hpp index 1992fc0..44947ea 100644 --- a/include/spblas/detail/operation_info_t.hpp +++ b/include/spblas/detail/operation_info_t.hpp @@ -15,6 +15,14 @@ #include #endif +#ifdef SPBLAS_ENABLE_CUSPARSE +#include +#endif + +#ifdef SPBLAS_ENABLE_ROCSPARSE +#include +#endif + namespace spblas { class operation_info_t { @@ -53,6 +61,13 @@ class operation_info_t { state_(std::move(state)) {} #endif +#ifdef SPBLAS_ENABLE_CUSPARSE + operation_info_t(index<> result_shape, offset_t result_nnz, + __cusparse::operation_state_t&& state) + : result_shape_(result_shape), result_nnz_(result_nnz), + state_(std::move(state)) {} +#endif + void update_impl_(index<> result_shape, offset_t result_nnz) { result_shape_ = result_shape; result_nnz_ = result_nnz; @@ -76,6 +91,16 @@ class operation_info_t { public: __aoclsparse::operation_state_t state_; #endif + +#ifdef SPBLAS_ENABLE_CUSPARSE +public: + __cusparse::operation_state_t state_; +#endif + +#ifdef SPBLAS_ENABLE_ROCSPARSE +public: + __rocsparse::operation_state_t state_; +#endif }; } // namespace spblas diff --git a/include/spblas/vendor/cusparse/detail/abstract_operation_state.hpp b/include/spblas/vendor/cusparse/detail/abstract_operation_state.hpp new file mode 100644 index 0000000..3eeb9bd --- /dev/null +++ b/include/spblas/vendor/cusparse/detail/abstract_operation_state.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +namespace spblas { +namespace __cusparse { + +class abstract_operation_state_t { +public: + // Common state that all operations need + cusparseHandle_t handle() const { + return handle_; + } + + // Make std::default_delete a friend so unique_ptr can delete us + friend struct std::default_delete; + +protected: + abstract_operation_state_t() { + cusparseCreate(&handle_); + } + + virtual ~abstract_operation_state_t() { + if (handle_) { + cusparseDestroy(handle_); + } + } + + cusparseHandle_t handle_; +}; + +} // namespace __cusparse +} // namespace spblas diff --git a/include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp b/include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp new file mode 100644 index 0000000..21e5e0c --- /dev/null +++ b/include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace spblas { + +namespace __cusparse { + +template + requires __detail::is_csr_view_v +cusparseSpMatDescr_t create_cusparse_handle(M&& m) { + cusparseSpMatDescr_t mat_descr; + __cusparse::throw_if_error(cusparseCreateCsr( + &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1], + m.values().size(), m.rowptr().data(), m.colind().data(), + m.values().data(), detail::cusparse_index_type_v>, + detail::cusparse_index_type_v>, + CUSPARSE_INDEX_BASE_ZERO, detail::cuda_data_type_v>)); + + return mat_descr; +} + +template + requires __ranges::contiguous_range +cusparseDnVecDescr_t create_cusparse_handle(V&& v) { + cusparseDnVecDescr_t vec_descr; + __cusparse::throw_if_error( + cusparseCreateDnVec(&vec_descr, __backend::shape(v), __ranges::data(v), + detail::cuda_data_type_v>)); + + return vec_descr; +} + +} // namespace __cusparse + +} // namespace spblas diff --git a/include/spblas/vendor/cusparse/detail/get_transpose.hpp b/include/spblas/vendor/cusparse/detail/get_transpose.hpp new file mode 100644 index 0000000..f68c3fc --- /dev/null +++ b/include/spblas/vendor/cusparse/detail/get_transpose.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace spblas { +namespace __cusparse { + +// +// Takes in a CSR or CSR_transpose (aka CSC) or CSC or CSC_transpose +// and returns the cusparseOperation_t value associated with it being +// represented in the CSR format +// +// CSR = CSR + NON_TRANSPOSE +// CSR_transpose = CSR + TRANSPOSE +// CSC = CSR + TRANSPOSE +// CSC_transpose = CSR + NON_TRANSPOSE +// +template +cusparseOperation_t get_transpose(M&& m) { + static_assert(__detail::has_csr_base || __detail::has_csc_base); + if constexpr (__detail::has_base) { + return get_transpose(m.base()); + } else if constexpr (__detail::is_csr_view_v) { + return CUSPARSE_OPERATION_NON_TRANSPOSE; + } else if constexpr (__detail::is_csc_view_v) { + return CUSPARSE_OPERATION_TRANSPOSE; + } +} + +} // namespace __cusparse +} // namespace spblas diff --git a/include/spblas/vendor/cusparse/detail/spmv_state_t.hpp b/include/spblas/vendor/cusparse/detail/spmv_state_t.hpp new file mode 100644 index 0000000..5e0cf64 --- /dev/null +++ b/include/spblas/vendor/cusparse/detail/spmv_state_t.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +#include "abstract_operation_state.hpp" + +namespace spblas { +namespace __cusparse { + +class spmv_state_t : public abstract_operation_state_t { +public: + spmv_state_t() = default; + ~spmv_state_t() { + if (a_descr_) { + cusparseDestroySpMat(a_descr_); + } + if (b_descr_) { + cusparseDestroyDnVec(b_descr_); + } + if (c_descr_) { + cusparseDestroyDnVec(c_descr_); + } + } + + // Accessors for the descriptors + cusparseSpMatDescr_t a_descriptor() const { + return a_descr_; + } + cusparseDnVecDescr_t b_descriptor() const { + return b_descr_; + } + cusparseDnVecDescr_t c_descriptor() const { + return c_descr_; + } + + // Setters for the descriptors + void set_a_descriptor(cusparseSpMatDescr_t descr) { + a_descr_ = descr; + } + void set_b_descriptor(cusparseDnVecDescr_t descr) { + b_descr_ = descr; + } + void set_c_descriptor(cusparseDnVecDescr_t descr) { + c_descr_ = descr; + } + +private: + cusparseSpMatDescr_t a_descr_ = nullptr; + cusparseDnVecDescr_t b_descr_ = nullptr; + cusparseDnVecDescr_t c_descr_ = nullptr; +}; + +} // namespace __cusparse +} // namespace spblas diff --git a/include/spblas/vendor/cusparse/multiply.hpp b/include/spblas/vendor/cusparse/multiply.hpp index 0f54549..4b68a98 100644 --- a/include/spblas/vendor/cusparse/multiply.hpp +++ b/include/spblas/vendor/cusparse/multiply.hpp @@ -1,123 +1,3 @@ #pragma once -#include -#include -#include - -#include -#include - -#include -#include - -#include "cuda_allocator.hpp" -#include "exception.hpp" -#include "types.hpp" - -namespace spblas { - -class spmv_state_t { -public: - spmv_state_t() : spmv_state_t(cusparse::cuda_allocator{}) {} - - spmv_state_t(cusparse::cuda_allocator alloc) - : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { - cusparseHandle_t handle; - __cusparse::throw_if_error(cusparseCreate(&handle)); - if (auto stream = alloc.stream()) { - cusparseSetStream(handle, stream); - } - handle_ = handle_manager(handle, [](cusparseHandle_t handle) { - __cusparse::throw_if_error(cusparseDestroy(handle)); - }); - } - - spmv_state_t(cusparse::cuda_allocator alloc, cusparseHandle_t handle) - : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { - handle_ = handle_manager(handle, [](cusparseHandle_t handle) { - // it is provided by user, we do not delete it at all. - }); - } - - ~spmv_state_t() { - alloc_.deallocate(workspace_, buffer_size_); - } - - template - requires __detail::has_csr_base && - __detail::has_contiguous_range_base && - __ranges::contiguous_range - void multiply(A&& a, B&& b, C&& c) { - auto a_base = __detail::get_ultimate_base(a); - auto b_base = __detail::get_ultimate_base(b); - using matrix_type = decltype(a_base); - using input_type = decltype(b_base); - using output_type = std::remove_reference_t; - using value_type = typename matrix_type::scalar_type; - - auto alpha_optional = __detail::get_scaling_factor(a, b); - tensor_scalar_t alpha = alpha_optional.value_or(1); - auto handle = this->handle_.get(); - - cusparseSpMatDescr_t mat; - __cusparse::throw_if_error(cusparseCreateCsr( - &mat, __backend::shape(a_base)[0], __backend::shape(a_base)[1], - a_base.values().size(), a_base.rowptr().data(), a_base.colind().data(), - a_base.values().data(), - to_cusparse_indextype(), - to_cusparse_indextype(), - CUSPARSE_INDEX_BASE_ZERO, to_cuda_datatype())); - - cusparseDnVecDescr_t vecb; - cusparseDnVecDescr_t vecc; - __cusparse::throw_if_error(cusparseCreateDnVec( - &vecb, b_base.size(), b_base.data(), - to_cuda_datatype())); - __cusparse::throw_if_error(cusparseCreateDnVec( - &vecc, c.size(), c.data(), - to_cuda_datatype())); - - value_type alpha_val = alpha; - value_type beta = 0.0; - long unsigned int buffer_size = 0; - // TODO: create a compute type for mixed precision computation - __cusparse::throw_if_error(cusparseSpMV_bufferSize( - handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha_val, mat, vecb, &beta, - vecc, to_cuda_datatype(), CUSPARSE_SPMV_ALG_DEFAULT, - &buffer_size)); - // only allocate the new workspace when the requiring workspace larger than - // current - if (buffer_size > this->buffer_size_) { - this->alloc_.deallocate(this->workspace_, buffer_size_); - this->buffer_size_ = buffer_size; - this->workspace_ = this->alloc_.allocate(buffer_size); - } - - __cusparse::throw_if_error( - cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha_val, mat, - vecb, &beta, vecc, to_cuda_datatype(), - CUSPARSE_SPMV_ALG_DEFAULT, this->workspace_)); - __cusparse::throw_if_error(cusparseDestroySpMat(mat)); - __cusparse::throw_if_error(cusparseDestroyDnVec(vecc)); - __cusparse::throw_if_error(cusparseDestroyDnVec(vecb)); - } - -private: - using handle_manager = - std::unique_ptr::element_type, - std::function>; - handle_manager handle_; - cusparse::cuda_allocator alloc_; - long unsigned int buffer_size_; - char* workspace_; -}; - -template - requires __detail::has_csr_base && - __detail::has_contiguous_range_base && - __ranges::contiguous_range -void multiply(spmv_state_t& spmv_state, A&& a, B&& b, C&& c) { - spmv_state.multiply(a, b, c); -} - -} // namespace spblas +#include "spmv_impl.hpp" diff --git a/include/spblas/vendor/cusparse/operation_state_t.hpp b/include/spblas/vendor/cusparse/operation_state_t.hpp new file mode 100644 index 0000000..df346db --- /dev/null +++ b/include/spblas/vendor/cusparse/operation_state_t.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "detail/abstract_operation_state.hpp" +#include + +namespace spblas { + +namespace __cusparse { + +class operation_state_t { +public: + operation_state_t() = default; + operation_state_t(std::unique_ptr&& state) + : state_(std::move(state)) {} + + // Move-only + operation_state_t(operation_state_t&&) = default; + operation_state_t& operator=(operation_state_t&&) = default; + + // No copying + operation_state_t(const operation_state_t&) = delete; + operation_state_t& operator=(const operation_state_t&) = delete; + + // Access the underlying state + template + T* get_state() { + return dynamic_cast(state_.get()); + } + + template + const T* get_state() const { + return dynamic_cast(state_.get()); + } + +private: + std::unique_ptr state_; +}; + +} // namespace __cusparse + +} // namespace spblas diff --git a/include/spblas/vendor/cusparse/spmv_impl.hpp b/include/spblas/vendor/cusparse/spmv_impl.hpp new file mode 100644 index 0000000..33a5825 --- /dev/null +++ b/include/spblas/vendor/cusparse/spmv_impl.hpp @@ -0,0 +1,96 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spblas { + +template + requires(__detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range && + detail::has_valid_cusparse_matrix_types_v && + detail::has_valid_cusparse_vector_types_v && + detail::has_valid_cusparse_vector_types_v) +void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) { + log_trace(""); + + auto x_base = __detail::get_ultimate_base(x); + auto a_base = __detail::get_ultimate_base(a); + + auto alpha_optional = __detail::get_scaling_factor(a, x); + tensor_scalar_t alpha = alpha_optional.value_or(1); + tensor_scalar_t beta = 0; + + // Get or create state + auto state = info.state_.get_state<__cusparse::spmv_state_t>(); + if (!state) { + info.state_ = __cusparse::operation_state_t( + std::make_unique<__cusparse::spmv_state_t>()); + state = info.state_.get_state<__cusparse::spmv_state_t>(); + } + + // Create or get matrix descriptor + if (!state->a_descriptor()) { + cusparseSpMatDescr_t a_descr = __cusparse::create_cusparse_handle(a_base); + state->set_a_descriptor(a_descr); + } + + // Create vector descriptors + cusparseDnVecDescr_t b_descr = __cusparse::create_cusparse_handle(x_base); + cusparseDnVecDescr_t c_descr = __cusparse::create_cusparse_handle(y); + state->set_b_descriptor(b_descr); + state->set_c_descriptor(c_descr); + + // Get operation type based on matrix format + auto a_transpose = __cusparse::get_transpose(a); + + // Get buffer size + size_t buffer_size; + __cusparse::throw_if_error(cusparseSpMV_bufferSize( + state->handle(), a_transpose, &alpha, state->a_descriptor(), + state->b_descriptor(), &beta, state->c_descriptor(), + detail::cuda_data_type_v>, CUSPARSE_SPMV_ALG_DEFAULT, + &buffer_size)); + + // Allocate buffer if needed + void* buffer = nullptr; + if (buffer_size > 0) { + cudaMalloc(&buffer, buffer_size); + } + + // Execute SpMV + __cusparse::throw_if_error( + cusparseSpMV(state->handle(), a_transpose, &alpha, state->a_descriptor(), + state->b_descriptor(), &beta, state->c_descriptor(), + detail::cuda_data_type_v>, + CUSPARSE_SPMV_ALG_DEFAULT, buffer)); + + // Free buffer if allocated + if (buffer) { + cudaFree(buffer); + } +} + +template + requires(__detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range && + detail::has_valid_cusparse_matrix_types_v && + detail::has_valid_cusparse_vector_types_v && + detail::has_valid_cusparse_vector_types_v) +void multiply(A&& a, X&& x, Y&& y) { + operation_info_t info; + multiply(info, std::forward(a), std::forward(x), std::forward(y)); +} + +} // namespace spblas diff --git a/include/spblas/vendor/cusparse/type_validation.hpp b/include/spblas/vendor/cusparse/type_validation.hpp new file mode 100644 index 0000000..29ec238 --- /dev/null +++ b/include/spblas/vendor/cusparse/type_validation.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace spblas { +namespace detail { + +template +static constexpr bool has_valid_cusparse_matrix_types_v = + is_valid_cusparse_scalar_type_v> && + is_valid_cusparse_index_type_v> && + is_valid_cusparse_index_type_v>; + +template +static constexpr bool has_valid_cusparse_vector_types_v = + is_valid_cusparse_scalar_type_v>; + +} // namespace detail +} // namespace spblas diff --git a/include/spblas/vendor/cusparse/types.hpp b/include/spblas/vendor/cusparse/types.hpp index 3b44081..8fb818f 100644 --- a/include/spblas/vendor/cusparse/types.hpp +++ b/include/spblas/vendor/cusparse/types.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -13,66 +14,68 @@ using offset_t = index_t; namespace detail { -/** - * mapping the type to cudaDataType_t - */ template -struct cuda_datatype_traits {}; +constexpr static bool is_valid_cusparse_scalar_type_v = + std::is_floating_point_v || std::is_same_v || + std::is_same_v; -#define MAP_CUDA_DATATYPE(_type, _value) \ - template <> \ - struct cuda_datatype_traits<_type> { \ - constexpr static cudaDataType_t value = _value; \ - } +template +constexpr static bool is_valid_cusparse_index_type_v = + std::is_same_v || std::is_same_v; -MAP_CUDA_DATATYPE(float, CUDA_R_32F); -MAP_CUDA_DATATYPE(double, CUDA_R_64F); -MAP_CUDA_DATATYPE(std::complex, CUDA_C_32F); -MAP_CUDA_DATATYPE(std::complex, CUDA_C_64F); +template +struct cuda_data_type; + +template <> +struct cuda_data_type { + constexpr static cudaDataType_t value = CUDA_R_32F; +}; + +template <> +struct cuda_data_type { + constexpr static cudaDataType_t value = CUDA_R_64F; +}; + +template <> +struct cuda_data_type> { + constexpr static cudaDataType_t value = CUDA_C_32F; +}; + +template <> +struct cuda_data_type> { + constexpr static cudaDataType_t value = CUDA_C_64F; +}; + +template <> +struct cuda_data_type { + constexpr static cudaDataType_t value = CUDA_R_8I; +}; + +template <> +struct cuda_data_type { + constexpr static cudaDataType_t value = CUDA_R_32I; +}; -#undef MAP_CUDA_DATATYPE +template +constexpr static cudaDataType_t cuda_data_type_v = cuda_data_type::value; -/** - * mapping the type to cusparseIndexType_t - */ template -struct cusparse_indextype_traits {}; +struct cuda_index_type; -#define MAP_CUSPARSE_INDEXTYPE(_type, _value) \ - template <> \ - struct cusparse_indextype_traits<_type> { \ - constexpr static cusparseIndexType_t value = _value; \ - } +template <> +struct cuda_index_type { + constexpr static cusparseIndexType_t value = CUSPARSE_INDEX_32I; +}; -MAP_CUSPARSE_INDEXTYPE(std::int32_t, CUSPARSE_INDEX_32I); -MAP_CUSPARSE_INDEXTYPE(std::int64_t, CUSPARSE_INDEX_64I); +template <> +struct cuda_index_type { + constexpr static cusparseIndexType_t value = CUSPARSE_INDEX_64I; +}; -#undef MAP_CUSPARSE_INDEXTYPE +template +constexpr static cusparseIndexType_t cusparse_index_type_v = + cuda_index_type::value; } // namespace detail -/** - * This is an alias for the `cudaDataType_t` equivalent of `T`. - * - * @tparam T a type - * - * @returns the actual `cudaDataType_t` - */ -template -constexpr cudaDataType_t to_cuda_datatype() { - return detail::cuda_datatype_traits::value; -} - -/** - * This is an alias for the `cudaIndexType_t` equivalent of `T`. - * - * @tparam T a type - * - * @returns the actual `cusparseIndexType_t` - */ -template -constexpr cusparseIndexType_t to_cusparse_indextype() { - return detail::cusparse_indextype_traits::value; -} - } // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/detail/detail.hpp b/include/spblas/vendor/onemkl_sycl/detail/detail.hpp index da2fc72..1002c3a 100644 --- a/include/spblas/vendor/onemkl_sycl/detail/detail.hpp +++ b/include/spblas/vendor/onemkl_sycl/detail/detail.hpp @@ -1,4 +1,6 @@ #pragma once #include "create_matrix_handle.hpp" +#include "execution_policy.hpp" #include "get_matrix_handle.hpp" +#include "get_queue.hpp" diff --git a/include/spblas/vendor/onemkl_sycl/detail/execution_policy.hpp b/include/spblas/vendor/onemkl_sycl/detail/execution_policy.hpp new file mode 100644 index 0000000..f9b233e --- /dev/null +++ b/include/spblas/vendor/onemkl_sycl/detail/execution_policy.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include +#include + +namespace spblas { + +namespace mkl { + +class parallel_policy { +public: + parallel_policy() {} + + template + sycl::queue get_queue(T* ptr) const { + return spblas::__mkl::get_pointer_queue(ptr); + } + + sycl::queue get_queue() const { + return sycl::queue(sycl::default_selector_v); + } +}; + +class device_policy { +public: + device_policy(const sycl::queue& queue) : queue_(queue) {} + + sycl::queue& get_queue() { + return queue_; + } + + const sycl::queue& get_queue() const { + return queue_; + } + + sycl::device get_device() const { + return queue_.get_device(); + } + + sycl::context get_context() const { + return queue_.get_context(); + } + +private: + sycl::queue queue_; +}; + +inline parallel_policy par; + +} // namespace mkl + +} // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/detail/get_pointer_device.hpp b/include/spblas/vendor/onemkl_sycl/detail/get_pointer_device.hpp new file mode 100644 index 0000000..d6de681 --- /dev/null +++ b/include/spblas/vendor/onemkl_sycl/detail/get_pointer_device.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +namespace spblas { + +namespace __mkl { + +inline std::vector global_contexts_; + +template +std::pair get_pointer_device(T* ptr) { + if (global_contexts_.empty()) { + for (auto&& platform : sycl::platform::get_platforms()) { + sycl::context context(platform.get_devices()); + + global_contexts_.push_back(context); + } + } + + for (auto&& context : global_contexts_) { + try { + sycl::device device = sycl::get_pointer_device(ptr, context); + return {device, context}; + } catch (...) { + } + } + + throw std::runtime_error( + "get_pointer_device: could not locate device corresponding to pointer"); +} + +template +sycl::queue get_pointer_queue(T* ptr) { + try { + auto&& [device, context] = get_pointer_device(ptr); + return sycl::queue(context, device); + } catch (...) { + return sycl::queue(sycl::cpu_selector_v); + } +} + +} // namespace __mkl + +} // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/detail/get_queue.hpp b/include/spblas/vendor/onemkl_sycl/detail/get_queue.hpp new file mode 100644 index 0000000..354da6c --- /dev/null +++ b/include/spblas/vendor/onemkl_sycl/detail/get_queue.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include + +namespace spblas { + +namespace __mkl { + +template +sycl::queue get_queue(const spblas::mkl::parallel_policy& policy, T* ptr) { + return policy.get_queue(ptr); +} + +template +sycl::queue& get_queue(spblas::mkl::device_policy& policy, T* ptr) { + return policy.get_queue(); +} + +} // namespace __mkl + +} // namespace spblas + +#if __has_include() + +#include + +namespace spblas { + +namespace __mkl { + +template +sycl::queue& get_queue(thrust::execution_policy& policy, T* ptr) { + return policy.get_queue(); +} + +} // namespace __mkl + +} // namespace spblas + +#endif diff --git a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp index 6341a07..9e67a70 100644 --- a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp @@ -11,6 +11,7 @@ #include #include +#include #include // @@ -26,18 +27,20 @@ namespace spblas { -template +template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && __detail::is_csr_view_v -operation_info_t multiply_compute(A&& a, B&& b, C&& c) { +operation_info_t + multiply_compute(ExecutionPolicy&& policy, A&& a, B&& b, C&& c) { log_trace(""); using oneapi::mkl::transpose; using oneapi::mkl::sparse::matmat_request; using oneapi::mkl::sparse::matrix_view_descr; - sycl::queue q(sycl::cpu_selector_v); + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); auto a_handle = __mkl::get_matrix_handle(q, a); auto b_handle = __mkl::get_matrix_handle(q, b); @@ -108,21 +111,22 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { c_handle, nullptr, descr, (void*) c_rowptr, q}}; } -template +template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && __detail::is_csr_view_v -void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) { - +void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + B&& b, C&& c) { log_trace(""); auto alpha_optional = __detail::get_scaling_factor(a, b); tensor_scalar_t alpha = alpha_optional.value_or(1); using oneapi::mkl::sparse::matmat_request; - sycl::queue q(sycl::cpu_selector_v); using O = tensor_offset_t; + auto&& q = info.state_.q; + O* c_rowptr = (O*) info.state_.c_rowptr; auto a_handle = __mkl::get_matrix_handle(q, a, info.state_.a_handle); @@ -156,6 +160,24 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) { } } +template + requires(__detail::has_csr_base || __detail::has_csc_base) && + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::is_csr_view_v +operation_info_t multiply_compute(A&& a, B&& b, C&& c) { + return multiply_compute(mkl::par, std::forward(a), std::forward(b), + std::forward(c)); +} + +template + requires(__detail::has_csr_base || __detail::has_csc_base) && + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::is_csr_view_v +void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) { + multiply_fill(mkl::par, info, std::forward(a), std::forward(b), + std::forward(c)); +} + template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && diff --git a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp index 75cfb20..856e711 100644 --- a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp @@ -27,7 +27,7 @@ namespace spblas { -template +template requires( (__detail::has_csr_base || __detail::has_csc_base) && __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && @@ -35,14 +35,15 @@ template __mdspan::layout_right> && std::is_same_v::layout_type, __mdspan::layout_right>) -void multiply(A&& a, X&& x, Y&& y) { +void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { log_trace(""); auto x_base = __detail::get_ultimate_base(x); auto alpha_optional = __detail::get_scaling_factor(a, x); tensor_scalar_t alpha = alpha_optional.value_or(1); - sycl::queue q(sycl::cpu_selector_v); + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); auto a_handle = __mkl::get_matrix_handle(q, a); auto a_transpose = __mkl::get_transpose(a); @@ -58,4 +59,17 @@ void multiply(A&& a, X&& x, Y&& y) { } } +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +void multiply(A&& a, X&& x, Y&& y) { + multiply(mkl::par, std::forward(a), std::forward(x), + std::forward(y)); +} + } // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp index 9035c8c..33e34b5 100644 --- a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp @@ -26,18 +26,20 @@ namespace spblas { -template +template requires((__detail::has_csr_base || __detail::has_csc_base) && __detail::has_contiguous_range_base && __ranges::contiguous_range) -void multiply(A&& a, X&& x, Y&& y) { +void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { log_trace(""); auto x_base = __detail::get_ultimate_base(x); auto alpha_optional = __detail::get_scaling_factor(a, x); tensor_scalar_t alpha = alpha_optional.value_or(1); - sycl::queue q(sycl::cpu_selector_v); + auto a_data = __detail::get_ultimate_base(a).values().data(); + + auto&& q = __mkl::get_queue(policy, a_data); auto a_handle = __mkl::get_matrix_handle(q, a); auto a_transpose = __mkl::get_transpose(a); @@ -51,4 +53,13 @@ void multiply(A&& a, X&& x, Y&& y) { } } +template + requires((__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_contiguous_range_base && + __ranges::contiguous_range) +void multiply(A&& a, X&& x, Y&& y) { + multiply(mkl::par, std::forward(a), std::forward(x), + std::forward(y)); +} + } // namespace spblas diff --git a/include/spblas/vendor/rocsparse/detail/abstract_operation_state.hpp b/include/spblas/vendor/rocsparse/detail/abstract_operation_state.hpp new file mode 100644 index 0000000..48aee79 --- /dev/null +++ b/include/spblas/vendor/rocsparse/detail/abstract_operation_state.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +namespace spblas { +namespace __rocsparse { + +class abstract_operation_state_t { +public: + // Common state that all operations need + rocsparse_handle handle() const { + return handle_; + } + + // Make std::default_delete a friend so unique_ptr can delete us + friend struct std::default_delete; + +protected: + abstract_operation_state_t() { + rocsparse_create_handle(&handle_); + } + + virtual ~abstract_operation_state_t() { + if (handle_) { + rocsparse_destroy_handle(handle_); + } + } + + rocsparse_handle handle_; +}; + +} // namespace __rocsparse +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/detail/get_transpose.hpp b/include/spblas/vendor/rocsparse/detail/get_transpose.hpp new file mode 100644 index 0000000..9279f66 --- /dev/null +++ b/include/spblas/vendor/rocsparse/detail/get_transpose.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace spblas { +namespace __rocsparse { + +// +// Takes in a CSR or CSR_transpose (aka CSC) or CSC or CSC_transpose +// and returns the rocsparse_operation value associated with it being +// represented in the CSR format +// +// CSR = CSR + NON_TRANSPOSE +// CSR_transpose = CSR + TRANSPOSE +// CSC = CSR + TRANSPOSE +// CSC_transpose = CSR + NON_TRANSPOSE +// +template +rocsparse_operation get_transpose(M&& m) { + static_assert(__detail::has_csr_base || __detail::has_csc_base); + if constexpr (__detail::has_base) { + return get_transpose(m.base()); + } else if constexpr (__detail::is_csr_view_v) { + return rocsparse_operation_none; + } else if constexpr (__detail::is_csc_view_v) { + return rocsparse_operation_transpose; + } +} + +} // namespace __rocsparse +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/detail/rocsparse_tensors.hpp b/include/spblas/vendor/rocsparse/detail/rocsparse_tensors.hpp new file mode 100644 index 0000000..4cde18f --- /dev/null +++ b/include/spblas/vendor/rocsparse/detail/rocsparse_tensors.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace spblas { +namespace __rocsparse { + +template + requires __detail::is_csr_view_v +rocsparse_spmat_descr create_rocsparse_handle(M&& m) { + rocsparse_spmat_descr mat_descr; + throw_if_error(rocsparse_create_csr_descr( + &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1], + m.values().size(), m.rowptr().data(), m.colind().data(), + m.values().data(), detail::rocsparse_index_type_v>, + detail::rocsparse_index_type_v>, + rocsparse_index_base_zero, + detail::rocsparse_data_type_v>)); + + return mat_descr; +} + +template + requires __ranges::contiguous_range +rocsparse_dnvec_descr create_rocsparse_handle(V&& v) { + rocsparse_dnvec_descr vec_descr; + throw_if_error(rocsparse_create_dnvec_descr( + &vec_descr, __backend::shape(v), __ranges::data(v), + detail::rocsparse_data_type_v>)); + + return vec_descr; +} + +} // namespace __rocsparse +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/detail/spmv_impl.hpp b/include/spblas/vendor/rocsparse/detail/spmv_impl.hpp new file mode 100644 index 0000000..5ade039 --- /dev/null +++ b/include/spblas/vendor/rocsparse/detail/spmv_impl.hpp @@ -0,0 +1,84 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spblas { + +template + requires(__detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range && + detail::has_valid_rocsparse_matrix_types_v && + detail::has_valid_rocsparse_vector_types_v && + detail::has_valid_rocsparse_vector_types_v) +void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + + auto alpha_optional = __detail::get_scaling_factor(a, b); + tensor_scalar_t alpha = alpha_optional.value_or(1); + tensor_scalar_t beta = 0; + + // Get or create state + auto state = info.state_.get_state<__rocsparse::spmv_state_t>(); + if (!state) { + info.state_ = __rocsparse::operation_state_t( + std::make_unique<__rocsparse::spmv_state_t>()); + state = info.state_.get_state<__rocsparse::spmv_state_t>(); + } + + // Create descriptors + auto a_descr = __rocsparse::create_rocsparse_handle(a_base); + auto b_descr = __rocsparse::create_rocsparse_handle(b_base); + auto c_descr = __rocsparse::create_rocsparse_handle(c); + + state->set_a_descriptor(a_descr); + state->set_b_descriptor(b_descr); + state->set_c_descriptor(c_descr); + + // Get operation type based on matrix format + auto a_transpose = __rocsparse::get_transpose(a); + + // Get buffer size + size_t buffer_size = 0; + __rocsparse::throw_if_error(rocsparse_spmv( + state->handle(), a_transpose, &alpha, state->a_descriptor(), + state->b_descriptor(), &beta, state->c_descriptor(), + detail::rocsparse_data_type_v>, + rocsparse_spmv_alg_csr_stream, rocsparse_spmv_stage_buffer_size, + &buffer_size, nullptr)); + + // Allocate buffer if needed + state->allocate_workspace(buffer_size); + + // Execute SpMV + __rocsparse::throw_if_error(rocsparse_spmv( + state->handle(), a_transpose, &alpha, state->a_descriptor(), + state->b_descriptor(), &beta, state->c_descriptor(), + detail::rocsparse_data_type_v>, + rocsparse_spmv_alg_csr_stream, rocsparse_spmv_stage_compute, &buffer_size, + state->workspace())); +} + +template + requires(__detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range && + detail::has_valid_rocsparse_matrix_types_v && + detail::has_valid_rocsparse_vector_types_v && + detail::has_valid_rocsparse_vector_types_v) +void multiply(A&& a, B&& b, C&& c) { + operation_info_t info; + multiply(info, std::forward(a), std::forward(b), std::forward(c)); +} + +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/detail/spmv_state_t.hpp b/include/spblas/vendor/rocsparse/detail/spmv_state_t.hpp new file mode 100644 index 0000000..5ee9689 --- /dev/null +++ b/include/spblas/vendor/rocsparse/detail/spmv_state_t.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +#include "../hip_allocator.hpp" +#include "abstract_operation_state.hpp" + +namespace spblas { +namespace __rocsparse { + +class spmv_state_t : public abstract_operation_state_t { +public: + spmv_state_t() : spmv_state_t(rocsparse::hip_allocator{}) {} + + spmv_state_t(rocsparse::hip_allocator alloc) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr), a_descr_(nullptr), + b_descr_(nullptr), c_descr_(nullptr) {} + + ~spmv_state_t() { + if (workspace_) { + alloc_.deallocate(workspace_, buffer_size_); + } + if (a_descr_) { + rocsparse_destroy_spmat_descr(a_descr_); + } + if (b_descr_) { + rocsparse_destroy_dnvec_descr(b_descr_); + } + if (c_descr_) { + rocsparse_destroy_dnvec_descr(c_descr_); + } + } + + // Workspace management + void* workspace() const { + return workspace_; + } + size_t buffer_size() const { + return buffer_size_; + } + + void allocate_workspace(size_t size) { + if (size > buffer_size_) { + if (workspace_) { + alloc_.deallocate(workspace_, buffer_size_); + } + buffer_size_ = size; + workspace_ = alloc_.allocate(size); + } + } + + // Descriptor accessors + rocsparse_spmat_descr a_descriptor() const { + return a_descr_; + } + rocsparse_dnvec_descr b_descriptor() const { + return b_descr_; + } + rocsparse_dnvec_descr c_descriptor() const { + return c_descr_; + } + + // Descriptor setters + void set_a_descriptor(rocsparse_spmat_descr descr) { + if (a_descr_) { + rocsparse_destroy_spmat_descr(a_descr_); + } + a_descr_ = descr; + } + + void set_b_descriptor(rocsparse_dnvec_descr descr) { + if (b_descr_) { + rocsparse_destroy_dnvec_descr(b_descr_); + } + b_descr_ = descr; + } + + void set_c_descriptor(rocsparse_dnvec_descr descr) { + if (c_descr_) { + rocsparse_destroy_dnvec_descr(c_descr_); + } + c_descr_ = descr; + } + +private: + rocsparse::hip_allocator alloc_; + size_t buffer_size_; + char* workspace_; + + // Descriptors + rocsparse_spmat_descr a_descr_; + rocsparse_dnvec_descr b_descr_; + rocsparse_dnvec_descr c_descr_; +}; + +} // namespace __rocsparse +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/multiply.hpp b/include/spblas/vendor/rocsparse/multiply.hpp index f276fbf..c2fcf00 100644 --- a/include/spblas/vendor/rocsparse/multiply.hpp +++ b/include/spblas/vendor/rocsparse/multiply.hpp @@ -1,125 +1,3 @@ #pragma once -#include -#include -#include - -#include -#include - -#include -#include - -#include "exception.hpp" -#include "hip_allocator.hpp" -#include "types.hpp" - -namespace spblas { - -class spmv_state_t { -public: - spmv_state_t() : spmv_state_t(rocsparse::hip_allocator{}) {} - - spmv_state_t(rocsparse::hip_allocator alloc) - : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { - rocsparse_handle handle; - __rocsparse::throw_if_error(rocsparse_create_handle(&handle)); - if (auto stream = alloc.stream()) { - rocsparse_set_stream(handle, stream); - } - handle_ = handle_manager(handle, [](rocsparse_handle handle) { - __rocsparse::throw_if_error(rocsparse_destroy_handle(handle)); - }); - } - - spmv_state_t(rocsparse::hip_allocator alloc, rocsparse_handle handle) - : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { - handle_ = handle_manager(handle, [](rocsparse_handle handle) { - // it is provided by user, we do not delete it at all. - }); - } - - ~spmv_state_t() { - alloc_.deallocate(workspace_, buffer_size_); - } - - template - requires __detail::has_csr_base && - __detail::has_contiguous_range_base && - __ranges::contiguous_range - void multiply(A&& a, B&& b, C&& c) { - auto a_base = __detail::get_ultimate_base(a); - auto b_base = __detail::get_ultimate_base(b); - using matrix_type = decltype(a_base); - using input_type = decltype(b_base); - using output_type = std::remove_reference_t; - using value_type = typename matrix_type::scalar_type; - - auto alpha_optional = __detail::get_scaling_factor(a, b); - tensor_scalar_t alpha = alpha_optional.value_or(1); - auto handle = this->handle_.get(); - - rocsparse_spmat_descr mat; - __rocsparse::throw_if_error(rocsparse_create_csr_descr( - &mat, __backend::shape(a_base)[0], __backend::shape(a_base)[1], - a_base.values().size(), a_base.rowptr().data(), a_base.colind().data(), - a_base.values().data(), - to_rocsparse_indextype(), - to_rocsparse_indextype(), - rocsparse_index_base_zero, to_rocsparse_datatype())); - rocsparse_dnvec_descr vecb; - rocsparse_dnvec_descr vecc; - __rocsparse::throw_if_error(rocsparse_create_dnvec_descr( - &vecb, b_base.size(), b_base.data(), - to_rocsparse_datatype())); - __rocsparse::throw_if_error(rocsparse_create_dnvec_descr( - &vecc, c.size(), c.data(), - to_rocsparse_datatype())); - value_type alpha_val = alpha; - value_type beta = 0.0; - long unsigned int buffer_size = 0; - // TODO: create a compute type for mixed precision computation - __rocsparse::throw_if_error(rocsparse_spmv( - handle, rocsparse_operation_none, &alpha_val, mat, vecb, &beta, vecc, - to_rocsparse_datatype(), rocsparse_spmv_alg_csr_stream, - rocsparse_spmv_stage_buffer_size, &buffer_size, nullptr)); - // only allocate the new workspace when the requiring workspace larger than - // current - if (buffer_size > this->buffer_size_) { - this->alloc_.deallocate(this->workspace_, buffer_size_); - this->buffer_size_ = buffer_size; - this->workspace_ = this->alloc_.allocate(buffer_size); - } - __rocsparse::throw_if_error(rocsparse_spmv( - handle, rocsparse_operation_none, &alpha_val, mat, vecb, &beta, vecc, - to_rocsparse_datatype(), rocsparse_spmv_alg_csr_stream, - rocsparse_spmv_stage_preprocess, &this->buffer_size_, - this->workspace_)); - __rocsparse::throw_if_error(rocsparse_spmv( - handle, rocsparse_operation_none, &alpha_val, mat, vecb, &beta, vecc, - to_rocsparse_datatype(), rocsparse_spmv_alg_csr_stream, - rocsparse_spmv_stage_compute, &this->buffer_size_, this->workspace_)); - __rocsparse::throw_if_error(rocsparse_destroy_spmat_descr(mat)); - __rocsparse::throw_if_error(rocsparse_destroy_dnvec_descr(vecc)); - __rocsparse::throw_if_error(rocsparse_destroy_dnvec_descr(vecb)); - } - -private: - using handle_manager = - std::unique_ptr::element_type, - std::function>; - handle_manager handle_; - rocsparse::hip_allocator alloc_; - long unsigned int buffer_size_; - char* workspace_; -}; - -template - requires __detail::has_csr_base && - __detail::has_contiguous_range_base && - __ranges::contiguous_range -void multiply(spmv_state_t& spmv_state, A&& a, B&& b, C&& c) { - spmv_state.multiply(a, b, c); -} - -} // namespace spblas +#include diff --git a/include/spblas/vendor/rocsparse/operation_state_t.hpp b/include/spblas/vendor/rocsparse/operation_state_t.hpp new file mode 100644 index 0000000..5fd6f27 --- /dev/null +++ b/include/spblas/vendor/rocsparse/operation_state_t.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "detail/abstract_operation_state.hpp" +#include + +namespace spblas { +namespace __rocsparse { + +class operation_state_t { +public: + operation_state_t() = default; + operation_state_t(std::unique_ptr&& state) + : state_(std::move(state)) {} + + // Move-only + operation_state_t(operation_state_t&&) = default; + operation_state_t& operator=(operation_state_t&&) = default; + + // No copying + operation_state_t(const operation_state_t&) = delete; + operation_state_t& operator=(const operation_state_t&) = delete; + + // Access the underlying state + template + T* get_state() { + return dynamic_cast(state_.get()); + } + + template + const T* get_state() const { + return dynamic_cast(state_.get()); + } + +private: + std::unique_ptr state_; +}; + +} // namespace __rocsparse +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/type_validation.hpp b/include/spblas/vendor/rocsparse/type_validation.hpp new file mode 100644 index 0000000..151f0a1 --- /dev/null +++ b/include/spblas/vendor/rocsparse/type_validation.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace spblas { +namespace detail { + +template +static constexpr bool has_valid_rocsparse_matrix_types_v = + is_valid_rocsparse_scalar_type_v> && + is_valid_rocsparse_index_type_v> && + is_valid_rocsparse_index_type_v>; + +template +static constexpr bool has_valid_rocsparse_vector_types_v = + is_valid_rocsparse_scalar_type_v>; + +} // namespace detail +} // namespace spblas diff --git a/include/spblas/vendor/rocsparse/types.hpp b/include/spblas/vendor/rocsparse/types.hpp index 82b49ec..e46e6b7 100644 --- a/include/spblas/vendor/rocsparse/types.hpp +++ b/include/spblas/vendor/rocsparse/types.hpp @@ -2,76 +2,86 @@ #include #include +#include #include namespace spblas { -using index_t = std::int64_t; -using offset_t = std::int64_t; +using index_t = std::int32_t; +using offset_t = index_t; namespace detail { -/** - * mapping the type to rocsparse_datatype - */ template -struct rocsparse_datatype_traits {}; +constexpr static bool is_valid_rocsparse_scalar_type_v = + std::is_same_v || std::is_same_v || + std::is_floating_point_v; -#define MAP_ROCSPARSE_DATATYPE(_type, _value) \ - template <> \ - struct rocsparse_datatype_traits<_type> { \ - constexpr static rocsparse_datatype value = _value; \ - } - -MAP_ROCSPARSE_DATATYPE(float, rocsparse_datatype_f32_r); -MAP_ROCSPARSE_DATATYPE(double, rocsparse_datatype_f64_r); -MAP_ROCSPARSE_DATATYPE(std::complex, rocsparse_datatype_f32_c); -MAP_ROCSPARSE_DATATYPE(std::complex, rocsparse_datatype_f64_c); +template +constexpr static bool is_valid_rocsparse_index_type_v = + std::is_same_v || std::is_same_v || + std::is_same_v; -#undef MAP_ROCSPARSE_DATATYPE +template +struct rocsparse_data_type; + +template <> +struct rocsparse_data_type { + constexpr static rocsparse_datatype value = rocsparse_datatype_i32_r; +}; + +template <> +struct rocsparse_data_type { + constexpr static rocsparse_datatype value = rocsparse_datatype_u32_r; +}; + +template <> +struct rocsparse_data_type { + constexpr static rocsparse_datatype value = rocsparse_datatype_f32_r; +}; + +template <> +struct rocsparse_data_type { + constexpr static rocsparse_datatype value = rocsparse_datatype_f64_r; +}; + +template <> +struct rocsparse_data_type> { + constexpr static rocsparse_datatype value = rocsparse_datatype_f32_c; +}; + +template <> +struct rocsparse_data_type> { + constexpr static rocsparse_datatype value = rocsparse_datatype_f64_c; +}; -/** - * mapping the type to rocsparse_indextype - */ template -struct rocsparse_indextype_traits {}; +constexpr static rocsparse_datatype rocsparse_data_type_v = + rocsparse_data_type::value; -#define MAP_ROCSPARSE_INDEXTYPE(_type, _value) \ - template <> \ - struct rocsparse_indextype_traits<_type> { \ - constexpr static rocsparse_indextype value = _value; \ - } +template +struct rocsparse_index_type; -MAP_ROCSPARSE_INDEXTYPE(std::int32_t, rocsparse_indextype_i32); -MAP_ROCSPARSE_INDEXTYPE(std::int64_t, rocsparse_indextype_i64); +template <> +struct rocsparse_index_type { + constexpr static rocsparse_indextype value = rocsparse_indextype_u16; +}; -#undef MAP_ROCSPARSE_INDEXTYPE +template <> +struct rocsparse_index_type { + constexpr static rocsparse_indextype value = rocsparse_indextype_i32; +}; -} // namespace detail +template <> +struct rocsparse_index_type { + constexpr static rocsparse_indextype value = rocsparse_indextype_i64; +}; -/** - * This is an alias for the `rocsparse_datatype` equivalent of `T`. - * - * @tparam T a type - * - * @returns the actual `rocsparse_datatype` - */ -template -constexpr rocsparse_datatype to_rocsparse_datatype() { - return detail::rocsparse_datatype_traits::value; -} - -/** - * This is an alias for the `rocsparse_indextype` equivalent of `T`. - * - * @tparam T a type - * - * @returns the actual `rocsparse_indextype` - */ template -constexpr rocsparse_indextype to_rocsparse_indextype() { - return detail::rocsparse_indextype_traits::value; -} +constexpr static rocsparse_indextype rocsparse_index_type_v = + rocsparse_index_type::value; + +} // namespace detail } // namespace spblas diff --git a/test/gtest/CMakeLists.txt b/test/gtest/CMakeLists.txt index 97c54d2..05e14d2 100644 --- a/test/gtest/CMakeLists.txt +++ b/test/gtest/CMakeLists.txt @@ -1,32 +1,38 @@ enable_testing() -function(add_device_test file_list) - add_executable(spblas-tests ${${file_list}}) + +set(TEST_SOURCES) + +# CPU tests +if (SPBLAS_CPU_BACKEND) + list(APPEND TEST_SOURCES + spmv_test.cpp + spmm_test.cpp + spgemm_test.cpp + spgemm_csr_csc.cpp + add_test.cpp + transpose_test.cpp + triangular_solve_test.cpp) +endif() + +# GPU tests +if (SPBLAS_GPU_BACKEND) if (ENABLE_ROCSPARSE) - set_source_files_properties(${${file_list}} PROPERTIES LANGUAGE HIP) - elseif (ENABLE_CUSPARSE) - target_link_libraries(spblas-tests Thrust) - else() - message(FATAL_ERROR "Device backend not found.") + set_source_files_properties(device/spmv_test.cpp PROPERTIES LANGUAGE HIP) endif() -endfunction() - -if (NOT SPBLAS_GPU_BACKEND) - add_executable( - spblas-tests - spmv_test.cpp - spmm_test.cpp - spgemm_test.cpp - spgemm_csr_csc.cpp - add_test.cpp - transpose_test.cpp - triangular_solve_test.cpp - ) -else() - set(TEST_SOURCES device/spmv_test.cpp) - add_device_test(TEST_SOURCES) + list(APPEND TEST_SOURCES device/spmv_test.cpp) endif() +add_executable(spblas-tests ${TEST_SOURCES}) target_link_libraries(spblas-tests spblas fmt GTest::gtest_main) +# Backend-specific test configuration +if (ENABLE_ROCSPARSE) + target_link_libraries(spblas-tests roc::rocthrust) +elseif (ENABLE_CUSPARSE) + target_link_libraries(spblas-tests Thrust) +elseif (ENABLE_ONEMKL_SYCL) + target_link_libraries(spblas-tests sycl_thrust) +endif() + include(GoogleTest) gtest_discover_tests(spblas-tests) diff --git a/test/gtest/device/spmv_test.cpp b/test/gtest/device/spmv_test.cpp index cf28f89..e3434ae 100644 --- a/test/gtest/device/spmv_test.cpp +++ b/test/gtest/device/spmv_test.cpp @@ -1,4 +1,3 @@ - #include "../util.hpp" #include @@ -9,7 +8,7 @@ using value_t = float; using index_t = spblas::index_t; using offset_t = spblas::offset_t; -TEST(CsrView, SpMV) { +TEST(thrust_CsrView, SpMV) { for (auto&& [num_rows, num_cols, nnz] : util::dims) { auto [values, rowptr, colind, shape, _] = spblas::generate_csr(num_rows, num_cols, @@ -32,8 +31,7 @@ TEST(CsrView, SpMV) { std::span b_span(d_b.data().get(), num_cols); std::span c_span(d_c.data().get(), num_rows); - spblas::spmv_state_t state; - spblas::multiply(state, a, b_span, c_span); + spblas::multiply(a, b_span, c_span); thrust::copy(d_c.begin(), d_c.end(), c.begin()); @@ -53,7 +51,7 @@ TEST(CsrView, SpMV) { } } -TEST(CsrView, SpMV_Ascaled) { +TEST(thrust_CsrView, SpMV_Ascaled) { for (auto&& [num_rows, num_cols, nnz] : {std::tuple(1000, 100, 100), std::tuple(100, 1000, 10000), std::tuple(40, 40, 1000)}) { @@ -79,8 +77,7 @@ TEST(CsrView, SpMV_Ascaled) { std::span b_span(d_b.data().get(), num_cols); std::span c_span(d_c.data().get(), num_rows); - spblas::spmv_state_t state; - spblas::multiply(state, spblas::scaled(alpha, a), b_span, c_span); + spblas::multiply(spblas::scaled(alpha, a), b_span, c_span); thrust::copy(d_c.begin(), d_c.end(), c.begin()); @@ -101,7 +98,7 @@ TEST(CsrView, SpMV_Ascaled) { } } -TEST(CsrView, SpMV_BScaled) { +TEST(thrust_CsrView, SpMV_BScaled) { for (auto&& [num_rows, num_cols, nnz] : {std::tuple(1000, 100, 100), std::tuple(100, 1000, 10000), std::tuple(40, 40, 1000)}) { @@ -127,8 +124,7 @@ TEST(CsrView, SpMV_BScaled) { std::span b_span(d_b.data().get(), num_cols); std::span c_span(d_c.data().get(), num_rows); - spblas::spmv_state_t state; - spblas::multiply(state, a, spblas::scaled(alpha, b_span), c_span); + spblas::multiply(a, spblas::scaled(alpha, b_span), c_span); thrust::copy(d_c.begin(), d_c.end(), c.begin());