From b2a7bb1312d36cdb8634f9f2d793ff48d504fbed Mon Sep 17 00:00:00 2001 From: Benjamin Brock Date: Mon, 12 Jan 2026 15:49:34 -0800 Subject: [PATCH] First draft at adding support for `conjugate`. --- include/spblas/algorithms/algorithms.hpp | 3 + include/spblas/algorithms/conjugated.hpp | 13 ++ include/spblas/algorithms/conjugated_impl.hpp | 44 ++++ include/spblas/concepts.hpp | 1 + include/spblas/detail/view_inspectors.hpp | 22 ++ .../spblas/vendor/aoclsparse/spgemm_impl.hpp | 13 ++ .../spblas/vendor/aoclsparse/spmm_impl.hpp | 7 + .../spblas/vendor/aoclsparse/spmv_impl.hpp | 7 + .../aoclsparse/triangular_solve_impl.hpp | 7 + include/spblas/vendor/armpl/multiply_impl.hpp | 26 +++ .../vendor/armpl/triangular_solve_impl.hpp | 8 + include/spblas/vendor/cusparse/spmv_impl.hpp | 8 + .../detail/create_matrix_handle.hpp | 13 +- .../spblas/vendor/onemkl_sycl/spgemm_impl.hpp | 20 +- .../spblas/vendor/onemkl_sycl/spmm_impl.hpp | 9 +- .../spblas/vendor/onemkl_sycl/spmv_impl.hpp | 9 +- .../vendor/rocsparse/detail/spmv_impl.hpp | 8 + include/spblas/views/conjugated_view.hpp | 13 ++ include/spblas/views/conjugated_view_impl.hpp | 198 ++++++++++++++++++ include/spblas/views/inspectors.hpp | 19 ++ include/spblas/views/views.hpp | 2 + test/gtest/spmv_test.cpp | 82 ++++++++ 22 files changed, 525 insertions(+), 7 deletions(-) create mode 100644 include/spblas/algorithms/conjugated.hpp create mode 100644 include/spblas/algorithms/conjugated_impl.hpp create mode 100644 include/spblas/views/conjugated_view.hpp create mode 100644 include/spblas/views/conjugated_view_impl.hpp diff --git a/include/spblas/algorithms/algorithms.hpp b/include/spblas/algorithms/algorithms.hpp index b5fb92b..1b4af97 100644 --- a/include/spblas/algorithms/algorithms.hpp +++ b/include/spblas/algorithms/algorithms.hpp @@ -16,5 +16,8 @@ #include #include +#include +#include + #include #include diff --git a/include/spblas/algorithms/conjugated.hpp b/include/spblas/algorithms/conjugated.hpp new file mode 100644 index 0000000..95bb986 --- /dev/null +++ b/include/spblas/algorithms/conjugated.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace spblas { + +template +auto conjugated(M&& m); + +template +auto conjugated(V&& v); + +} // namespace spblas diff --git a/include/spblas/algorithms/conjugated_impl.hpp b/include/spblas/algorithms/conjugated_impl.hpp new file mode 100644 index 0000000..bb94c43 --- /dev/null +++ b/include/spblas/algorithms/conjugated_impl.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace spblas { + +namespace __detail { + +template +struct is_std_complex : std::false_type {}; + +template +struct is_std_complex> : std::true_type {}; + +template +inline constexpr bool is_std_complex_v = + is_std_complex>::value; + +} // namespace __detail + +template +auto conjugated(V&& v) { + if constexpr (__detail::is_std_complex_v>) { + return conjugated_view(std::forward(v)); + } else { + return std::forward(v); + } +} + +template +auto conjugated(M&& m) { + if constexpr (__detail::is_std_complex_v>) { + return conjugated_view(std::forward(m)); + } else { + return std::forward(m); + } +} + +} // namespace spblas diff --git a/include/spblas/concepts.hpp b/include/spblas/concepts.hpp index 03ab310..bdf8dbd 100644 --- a/include/spblas/concepts.hpp +++ b/include/spblas/concepts.hpp @@ -14,6 +14,7 @@ namespace spblas { - Instantiations of csc_view<...> - Instantiations of mdspan<...> with rank 2 - Instantiations of scaled_view where M is a matrix + - Instantiations of conjugated_view where M is a matrix */ template diff --git a/include/spblas/detail/view_inspectors.hpp b/include/spblas/detail/view_inspectors.hpp index 7e8ec74..05015d6 100644 --- a/include/spblas/detail/view_inspectors.hpp +++ b/include/spblas/detail/view_inspectors.hpp @@ -76,6 +76,28 @@ auto get_scaling_factor(T&& t, U&& u) { } } +// Inspect a tensor: does it have a conjugation view? Returns true if any +// conjugated_view appears in the chain of bases. +template +bool is_conjugated(T&& t) { + if constexpr (has_base) { + if constexpr (is_conjugated_view_v) { + return true; + } + return is_conjugated(t.base()); + } else { + if constexpr (is_conjugated_view_v) { + return true; + } + return false; + } +} + +template +bool is_conjugated(T&& t, U&& u) { + return is_conjugated(std::forward(t)) || is_conjugated(std::forward(u)); +} + template bool has_scaling_factor(T&& t) { return get_scaling_factor(t).has_value(); diff --git a/include/spblas/vendor/aoclsparse/spgemm_impl.hpp b/include/spblas/vendor/aoclsparse/spgemm_impl.hpp index 1101bcb..7eac9da 100644 --- a/include/spblas/vendor/aoclsparse/spgemm_impl.hpp +++ b/include/spblas/vendor/aoclsparse/spgemm_impl.hpp @@ -11,6 +11,7 @@ #include #include +#include #include "aocl_wrappers.hpp" #include "detail/detail.hpp" @@ -40,6 +41,12 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { auto a_base = __detail::get_ultimate_base(a); auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(c)) { + throw std::runtime_error( + "aoclsparse backend does not support conjugated views."); + } + using T = tensor_scalar_t; using I = tensor_index_t; using O = tensor_offset_t; @@ -111,6 +118,12 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) { using I = tensor_index_t; using O = tensor_offset_t; + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(c)) { + throw std::runtime_error( + "aoclsparse backend does not support conjugated views."); + } + auto alpha_optional = __detail::get_scaling_factor(a, b); tensor_scalar_t alpha = alpha_optional.value_or(1); diff --git a/include/spblas/vendor/aoclsparse/spmm_impl.hpp b/include/spblas/vendor/aoclsparse/spmm_impl.hpp index b116b84..65b8f76 100644 --- a/include/spblas/vendor/aoclsparse/spmm_impl.hpp +++ b/include/spblas/vendor/aoclsparse/spmm_impl.hpp @@ -11,6 +11,7 @@ #include "aoclsparse.h" #include +#include #include "aocl_wrappers.hpp" #include @@ -44,6 +45,12 @@ void multiply(A&& a, X&& x, Y&& y) { auto x_base = __detail::get_ultimate_base(x); auto y_base = __detail::get_ultimate_base(y); + if (__detail::is_conjugated(a) || __detail::is_conjugated(x) || + __detail::is_conjugated(y)) { + throw std::runtime_error( + "aoclsparse backend does not support conjugated views."); + } + aoclsparse_matrix csrA = __aoclsparse::create_matrix_handle(a_base); aoclsparse_operation opA = __aoclsparse::get_transpose(a); diff --git a/include/spblas/vendor/aoclsparse/spmv_impl.hpp b/include/spblas/vendor/aoclsparse/spmv_impl.hpp index 9199b09..423e6ed 100644 --- a/include/spblas/vendor/aoclsparse/spmv_impl.hpp +++ b/include/spblas/vendor/aoclsparse/spmv_impl.hpp @@ -11,6 +11,7 @@ #include "aoclsparse.h" #include +#include #include "aocl_wrappers.hpp" #include @@ -38,6 +39,12 @@ void multiply(A&& a, X&& x, Y&& y) { auto a_base = __detail::get_ultimate_base(a); auto x_base = __detail::get_ultimate_base(x); + if (__detail::is_conjugated(a) || __detail::is_conjugated(x) || + __detail::is_conjugated(y)) { + throw std::runtime_error( + "aoclsparse backend does not support conjugated views."); + } + aoclsparse_matrix csrA = __aoclsparse::create_matrix_handle(a_base); aoclsparse_operation opA = __aoclsparse::get_transpose(a); diff --git a/include/spblas/vendor/aoclsparse/triangular_solve_impl.hpp b/include/spblas/vendor/aoclsparse/triangular_solve_impl.hpp index 9ebafc4..cb5417c 100644 --- a/include/spblas/vendor/aoclsparse/triangular_solve_impl.hpp +++ b/include/spblas/vendor/aoclsparse/triangular_solve_impl.hpp @@ -11,6 +11,7 @@ #include "aoclsparse.h" #include +#include #include "aocl_wrappers.hpp" #include @@ -46,6 +47,12 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, auto a_base = __detail::get_ultimate_base(a); auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(x)) { + throw std::runtime_error( + "aoclsparse backend does not support conjugated views."); + } + using T = tensor_scalar_t; using I = tensor_index_t; using O = tensor_offset_t; diff --git a/include/spblas/vendor/armpl/multiply_impl.hpp b/include/spblas/vendor/armpl/multiply_impl.hpp index 9c07279..9a5dd49 100644 --- a/include/spblas/vendor/armpl/multiply_impl.hpp +++ b/include/spblas/vendor/armpl/multiply_impl.hpp @@ -4,6 +4,8 @@ #include +#include + #include #include #include @@ -21,6 +23,12 @@ void multiply(A&& a, B&& b, C&& c) { auto a_base = __detail::get_ultimate_base(a); auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(c)) { + throw std::runtime_error( + "armpl backend does not support conjugated views."); + } + auto alpha_optional = __detail::get_scaling_factor(a, b); tensor_scalar_t alpha = alpha_optional.value_or(1); @@ -47,6 +55,12 @@ void multiply(A&& a, B&& b, C&& c) { auto a_base = __detail::get_ultimate_base(a); auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(c)) { + throw std::runtime_error( + "armpl backend does not support conjugated views."); + } + auto alpha_optional = __detail::get_scaling_factor(a, b); tensor_scalar_t alpha = alpha_optional.value_or(1); @@ -92,6 +106,12 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { auto a_base = __detail::get_ultimate_base(a); auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(c)) { + throw std::runtime_error( + "armpl backend does not support conjugated views."); + } + auto alpha_optional = __detail::get_scaling_factor(a, b); tensor_scalar_t alpha = alpha_optional.value_or(1); @@ -121,6 +141,12 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) { log_trace(""); auto c_handle = info.state_.c_handle; + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(c)) { + throw std::runtime_error( + "armpl backend does not support conjugated views."); + } + __armpl::export_matrix_handle(info, c, c_handle); } diff --git a/include/spblas/vendor/armpl/triangular_solve_impl.hpp b/include/spblas/vendor/armpl/triangular_solve_impl.hpp index d0ca8ca..2a7a095 100644 --- a/include/spblas/vendor/armpl/triangular_solve_impl.hpp +++ b/include/spblas/vendor/armpl/triangular_solve_impl.hpp @@ -2,6 +2,8 @@ #include +#include + #include #include #include @@ -26,6 +28,12 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, auto a_base = __detail::get_ultimate_base(a); auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(x)) { + throw std::runtime_error( + "armpl backend does not support conjugated views."); + } + using T = tensor_scalar_t; using I = tensor_index_t; using O = tensor_offset_t; diff --git a/include/spblas/vendor/cusparse/spmv_impl.hpp b/include/spblas/vendor/cusparse/spmv_impl.hpp index 33a5825..bcbc68d 100644 --- a/include/spblas/vendor/cusparse/spmv_impl.hpp +++ b/include/spblas/vendor/cusparse/spmv_impl.hpp @@ -2,6 +2,8 @@ #include +#include + #include #include #include @@ -27,6 +29,12 @@ void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) { auto x_base = __detail::get_ultimate_base(x); auto a_base = __detail::get_ultimate_base(a); + if (__detail::is_conjugated(a) || __detail::is_conjugated(x) || + __detail::is_conjugated(y)) { + throw std::runtime_error( + "cusparse backend does not support conjugated views."); + } + auto alpha_optional = __detail::get_scaling_factor(a, x); tensor_scalar_t alpha = alpha_optional.value_or(1); tensor_scalar_t beta = 0; diff --git a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp index 5d7139f..89c60f2 100644 --- a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp +++ b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp @@ -2,6 +2,8 @@ #include +#include + #include namespace spblas { @@ -57,14 +59,19 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q, // CSC_transpose -> CSR + nontrans // template -oneapi::mkl::transpose get_transpose(M&& m) { +oneapi::mkl::transpose get_transpose(M&& m, bool conjugate = false) { static_assert(__detail::has_csr_base || __detail::has_csc_base); if constexpr (__detail::has_base) { - return get_transpose(m.base()); + return get_transpose(m.base(), conjugate); } else if constexpr (__detail::is_csr_view_v) { + if (conjugate) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugation for CSR views."); + } return oneapi::mkl::transpose::nontrans; } else if constexpr (__detail::is_csc_view_v) { - return oneapi::mkl::transpose::trans; + return conjugate ? oneapi::mkl::transpose::conjtrans + : oneapi::mkl::transpose::trans; } } diff --git a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp index 9e67a70..518838d 100644 --- a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp @@ -2,6 +2,8 @@ #include +#include + #include #include @@ -39,6 +41,11 @@ operation_info_t using oneapi::mkl::sparse::matmat_request; using oneapi::mkl::sparse::matrix_view_descr; + if (__detail::is_conjugated(c)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated output matrices."); + } + auto a_data = __detail::get_ultimate_base(a).values().data(); auto&& q = __mkl::get_queue(policy, a_data); @@ -69,8 +76,9 @@ operation_info_t oneapi::mkl::sparse::set_matmat_data( descr, matrix_view_descr::general, - __mkl::get_transpose(a), // view/op for A - matrix_view_descr::general, __mkl::get_transpose(b), // view/op for B + __mkl::get_transpose(a, __detail::is_conjugated(a)), // view/op for A + matrix_view_descr::general, + __mkl::get_transpose(b, __detail::is_conjugated(b)), // view/op for B matrix_view_descr::general); // view for C auto ev1 = oneapi::mkl::sparse::matmat(q, a_handle, b_handle, c_handle, @@ -119,6 +127,14 @@ void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a, B&& b, C&& c) { log_trace(""); + if (__detail::is_conjugated(c)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated output matrices."); + } + + (void) __mkl::get_transpose(a, __detail::is_conjugated(a)); + (void) __mkl::get_transpose(b, __detail::is_conjugated(b)); + auto alpha_optional = __detail::get_scaling_factor(a, b); tensor_scalar_t alpha = alpha_optional.value_or(1); diff --git a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp index 856e711..2e393b1 100644 --- a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp @@ -2,6 +2,8 @@ #include +#include + #include #include #include @@ -39,6 +41,11 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { log_trace(""); auto x_base = __detail::get_ultimate_base(x); + if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense matrices."); + } + auto alpha_optional = __detail::get_scaling_factor(a, x); tensor_scalar_t alpha = alpha_optional.value_or(1); @@ -46,7 +53,7 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { auto&& q = __mkl::get_queue(policy, a_data); auto a_handle = __mkl::get_matrix_handle(q, a); - auto a_transpose = __mkl::get_transpose(a); + auto a_transpose = __mkl::get_transpose(a, __detail::is_conjugated(a)); oneapi::mkl::sparse::gemm(q, oneapi::mkl::layout::row_major, a_transpose, oneapi::mkl::transpose::nontrans, alpha, a_handle, diff --git a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp index 33e34b5..acd7eae 100644 --- a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp @@ -2,6 +2,8 @@ #include +#include + #include #include #include @@ -34,6 +36,11 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { log_trace(""); auto x_base = __detail::get_ultimate_base(x); + if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense vectors."); + } + auto alpha_optional = __detail::get_scaling_factor(a, x); tensor_scalar_t alpha = alpha_optional.value_or(1); @@ -42,7 +49,7 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { auto&& q = __mkl::get_queue(policy, a_data); auto a_handle = __mkl::get_matrix_handle(q, a); - auto a_transpose = __mkl::get_transpose(a); + auto a_transpose = __mkl::get_transpose(a, __detail::is_conjugated(a)); oneapi::mkl::sparse::gemv(q, a_transpose, alpha, a_handle, __ranges::data(x_base), 0.0, __ranges::data(y)) diff --git a/include/spblas/vendor/rocsparse/detail/spmv_impl.hpp b/include/spblas/vendor/rocsparse/detail/spmv_impl.hpp index 5ade039..daa5f6f 100644 --- a/include/spblas/vendor/rocsparse/detail/spmv_impl.hpp +++ b/include/spblas/vendor/rocsparse/detail/spmv_impl.hpp @@ -2,6 +2,8 @@ #include +#include + #include #include #include @@ -24,6 +26,12 @@ void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { auto a_base = __detail::get_ultimate_base(a); auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(a) || __detail::is_conjugated(b) || + __detail::is_conjugated(c)) { + throw std::runtime_error( + "rocsparse backend does not support conjugated views."); + } + auto alpha_optional = __detail::get_scaling_factor(a, b); tensor_scalar_t alpha = alpha_optional.value_or(1); tensor_scalar_t beta = 0; diff --git a/include/spblas/views/conjugated_view.hpp b/include/spblas/views/conjugated_view.hpp new file mode 100644 index 0000000..237c268 --- /dev/null +++ b/include/spblas/views/conjugated_view.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +namespace spblas { + +// Conjugate a tensor of type `T`. +template +class conjugated_view; + +} // namespace spblas diff --git a/include/spblas/views/conjugated_view_impl.hpp b/include/spblas/views/conjugated_view_impl.hpp new file mode 100644 index 0000000..6036b1e --- /dev/null +++ b/include/spblas/views/conjugated_view_impl.hpp @@ -0,0 +1,198 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace spblas { + +// Conjugate a tensor of type `T`. +template +class conjugated_view; + +// Conjugated view for random access range +template + requires(__detail::__ranges::view && __ranges::random_access_range) +class conjugated_view : public view_base { +public: + using scalar_type = decltype(std::conj(std::declval>())); + using scalar_reference = scalar_type; + using index_type = tensor_index_t; + using offset_type = tensor_offset_t; + + explicit conjugated_view(V vector) + : vector_(vector), transform_view_(vector, transform_fn_{}) {} + + index_type shape() const noexcept { + return __backend::shape(base()); + } + + index_type size() const noexcept { + return __backend::size(base()); + } + + scalar_type operator[](index_type i) const { + return transform_view_[i]; + } + + auto base() { + return vector_; + } + + auto base() const { + return vector_; + } + + auto begin() { + return transform_view_.begin(); + } + + auto begin() const { + return transform_view_.begin(); + } + + auto end() { + return transform_view_.end(); + } + + auto end() const { + return transform_view_.end(); + } + +private: + struct transform_fn_ { + auto operator()(auto x) { + return std::conj(x); + } + }; + + __ranges::transform_view transform_view_; + +private: + V vector_; +}; + +template <__ranges::random_access_range R> +conjugated_view(R&& r) -> conjugated_view<__ranges::views::all_t>; + +// Conjugated view for matrices +template + requires(view) +class conjugated_view : public view_base { +public: + using scalar_type = decltype(std::conj(std::declval>())); + using scalar_reference = scalar_type; + using index_type = tensor_index_t; + using offset_type = tensor_offset_t; + + explicit conjugated_view(M matrix) : matrix_(matrix) {} + + auto shape() const noexcept { + return __backend::shape(base()); + } + + index_type size() const noexcept { + return __backend::size(base()); + } + + auto base() { + return matrix_; + } + + auto base() const { + return matrix_; + } + +private: + friend auto tag_invoke(__backend::size_fn_, conjugated_view matrix) { + return __backend::size(matrix.base()); + } + + friend auto tag_invoke(__backend::shape_fn_, conjugated_view matrix) { + return __backend::shape(matrix.base()); + } + + friend auto tag_invoke(__backend::lookup_fn_, conjugated_view matrix, + index_type i, index_type j) + requires(__backend::lookupable) + { + return std::conj(__backend::lookup(matrix.base(), i, j)); + } + + friend auto tag_invoke(__backend::rows_fn_, conjugated_view matrix) + requires(__backend::row_iterable) + { + auto unscaled_rows = __backend::rows(matrix.base()); + + return unscaled_rows | + __ranges::views::transform([](auto&& row_tuple) { + auto&& [column_index, row] = row_tuple; + + auto conjugated_row = + row | __ranges::views::transform([](auto&& element_tuple) { + auto&& [column_index, value] = element_tuple; + return std::pair(column_index, std::conj(value)); + }); + + return std::pair(column_index, conjugated_row); + }); + } + + friend auto tag_invoke(__backend::lookup_row_fn_, conjugated_view matrix, + index_type row_index) + requires(__backend::row_lookupable) + { + auto unscaled_row = __backend::lookup_row(matrix.base(), row_index); + + return unscaled_row | __ranges::views::transform([](auto&& element_tuple) { + auto&& [column_index, value] = element_tuple; + return std::pair(column_index, std::conj(value)); + }); + } + + friend auto tag_invoke(__backend::columns_fn_, conjugated_view matrix) + requires(__backend::column_iterable) + { + auto unscaled_columns = __backend::columns(matrix.base()); + + return unscaled_columns | + __ranges::views::transform([](auto&& column_tuple) { + auto&& [row_index, column] = column_tuple; + + auto conjugated_column = + column | __ranges::views::transform([](auto&& element_tuple) { + auto&& [row_index, value] = element_tuple; + return std::pair(row_index, std::conj(value)); + }); + + return std::pair(row_index, conjugated_column); + }); + } + + friend auto tag_invoke(__backend::lookup_column_fn_, conjugated_view matrix, + index_type column_index) + requires(__backend::column_lookupable) + { + auto unscaled_column = + __backend::lookup_column(matrix.base(), column_index); + + return unscaled_column | __ranges::views::transform([](auto&& element_tuple) { + auto&& [row_index, value] = element_tuple; + return std::pair(row_index, std::conj(value)); + }); + } + +private: + M matrix_; +}; + +template + requires(view) +conjugated_view(M m) -> conjugated_view; + +} // namespace spblas diff --git a/include/spblas/views/inspectors.hpp b/include/spblas/views/inspectors.hpp index 99e7ec6..157dcb0 100644 --- a/include/spblas/views/inspectors.hpp +++ b/include/spblas/views/inspectors.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -79,6 +80,24 @@ template static constexpr bool is_scaled_view_matrix_v = is_scaled_view_v && matrix().base())>; +template +struct is_instantiation_of_conjugated_view { + static constexpr bool value = false; +}; + +template +struct is_instantiation_of_conjugated_view> { + static constexpr bool value = true; +}; + +template +static constexpr bool is_conjugated_view_v = + is_instantiation_of_conjugated_view>::value; + +template +static constexpr bool is_conjugated_view_matrix_v = + is_conjugated_view_v && matrix().base())>; + template struct is_instantiation_of_matrix_opt { static constexpr bool value = false; diff --git a/include/spblas/views/views.hpp b/include/spblas/views/views.hpp index 4a180eb..b4d594a 100644 --- a/include/spblas/views/views.hpp +++ b/include/spblas/views/views.hpp @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include #include diff --git a/test/gtest/spmv_test.cpp b/test/gtest/spmv_test.cpp index c42546a..e48f789 100644 --- a/test/gtest/spmv_test.cpp +++ b/test/gtest/spmv_test.cpp @@ -1,5 +1,7 @@ #include +#include + #include "util.hpp" #include @@ -139,6 +141,86 @@ TEST(CscView, SpMV) { } } +#ifndef SPBLAS_VENDOR_BACKEND + +TEST(CsrView, SpMV_Aconjugated) { + using T = std::complex; + using I = spblas::index_t; + + for (auto&& [m, n, nnz] : util::dims) { + auto [values_real, rowptr, colind, shape, _] = + spblas::generate_csr(m, n, nnz); + + std::vector values(values_real.size()); + for (std::size_t i = 0; i < values.size(); i++) { + values[i] = T(values_real[i], static_cast((i % 7) + 1)); + } + + spblas::csr_view a(values, rowptr, colind, shape, nnz); + + std::vector b(n, T(1.0f, -2.0f)); + std::vector c(m, T(0.0f, 0.0f)); + + spblas::multiply(spblas::conjugated(a), b, c); + + std::vector c_ref(m, T(0.0f, 0.0f)); + + for (I i = 0; i < m; i++) { + for (I j_ptr = rowptr[i]; j_ptr < rowptr[i + 1]; j_ptr++) { + I j = colind[j_ptr]; + T v = values[j_ptr]; + + c_ref[i] += std::conj(v) * b[j]; + } + } + + for (I i = 0; i < c_ref.size(); i++) { + EXPECT_EQ_(c_ref[i].real(), c[i].real()); + EXPECT_EQ_(c_ref[i].imag(), c[i].imag()); + } + } +} + +TEST(CsrView, SpMV_Bconjugated) { + using T = std::complex; + using I = spblas::index_t; + + for (auto&& [m, n, nnz] : util::dims) { + auto [values_real, rowptr, colind, shape, _] = + spblas::generate_csr(m, n, nnz); + + std::vector values(values_real.size()); + for (std::size_t i = 0; i < values.size(); i++) { + values[i] = T(values_real[i], static_cast((i % 5) + 1)); + } + + spblas::csr_view a(values, rowptr, colind, shape, nnz); + + std::vector b(n, T(1.0f, -2.0f)); + std::vector c(m, T(0.0f, 0.0f)); + + spblas::multiply(a, spblas::conjugated(b), c); + + std::vector c_ref(m, T(0.0f, 0.0f)); + + for (I i = 0; i < m; i++) { + for (I j_ptr = rowptr[i]; j_ptr < rowptr[i + 1]; j_ptr++) { + I j = colind[j_ptr]; + T v = values[j_ptr]; + + c_ref[i] += v * std::conj(b[j]); + } + } + + for (I i = 0; i < c_ref.size(); i++) { + EXPECT_EQ_(c_ref[i].real(), c[i].real()); + EXPECT_EQ_(c_ref[i].imag(), c[i].imag()); + } + } +} + +#endif + TEST(CscView, SpMV_Ascaled) { using T = float; using I = spblas::index_t;