Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/spblas/algorithms/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@
#include <spblas/algorithms/scaled.hpp>
#include <spblas/algorithms/scaled_impl.hpp>

#include <spblas/algorithms/conjugated.hpp>
#include <spblas/algorithms/conjugated_impl.hpp>

#include <spblas/algorithms/transpose.hpp>
#include <spblas/algorithms/transpose_impl.hpp>
13 changes: 13 additions & 0 deletions include/spblas/algorithms/conjugated.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include <spblas/concepts.hpp>

namespace spblas {

template <matrix M>
auto conjugated(M&& m);

template <vector V>
auto conjugated(V&& v);

} // namespace spblas
44 changes: 44 additions & 0 deletions include/spblas/algorithms/conjugated_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once

#include <complex>
#include <type_traits>
#include <utility>

#include <spblas/concepts.hpp>
#include <spblas/views/conjugated_view.hpp>

namespace spblas {

namespace __detail {

template <typename T>
struct is_std_complex : std::false_type {};

template <typename T>
struct is_std_complex<std::complex<T>> : std::true_type {};

template <typename T>
inline constexpr bool is_std_complex_v =
is_std_complex<std::remove_cvref_t<T>>::value;

} // namespace __detail

template <vector V>
auto conjugated(V&& v) {
if constexpr (__detail::is_std_complex_v<tensor_scalar_t<V>>) {
return conjugated_view(std::forward<V>(v));
} else {
return std::forward<V>(v);
}
}

template <matrix M>
auto conjugated(M&& m) {
if constexpr (__detail::is_std_complex_v<tensor_scalar_t<M>>) {
return conjugated_view(std::forward<M>(m));
} else {
return std::forward<M>(m);
}
}

} // namespace spblas
1 change: 1 addition & 0 deletions include/spblas/concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace spblas {
- Instantiations of csc_view<...>
- Instantiations of mdspan<...> with rank 2
- Instantiations of scaled_view<T> where M is a matrix
- Instantiations of conjugated_view<T> where M is a matrix
*/

template <typename M>
Expand Down
22 changes: 22 additions & 0 deletions include/spblas/detail/view_inspectors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ auto get_scaling_factor(T&& t, U&& u) {
}
}

// Inspect a tensor: does it have a conjugation view? Returns true if any
// conjugated_view appears in the chain of bases.
template <tensor T>
bool is_conjugated(T&& t) {
if constexpr (has_base<T>) {
if constexpr (is_conjugated_view_v<T>) {
return true;
}
return is_conjugated(t.base());
} else {
if constexpr (is_conjugated_view_v<T>) {
return true;
}
return false;
}
}

template <tensor T, tensor U>
bool is_conjugated(T&& t, U&& u) {
return is_conjugated(std::forward<T>(t)) || is_conjugated(std::forward<U>(u));
}

template <tensor T>
bool has_scaling_factor(T&& t) {
return get_scaling_factor(t).has_value();
Expand Down
13 changes: 13 additions & 0 deletions include/spblas/vendor/aoclsparse/spgemm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <aoclsparse.h>
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include "detail/detail.hpp"
Expand Down Expand Up @@ -40,6 +41,12 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

using T = tensor_scalar_t<C>;
using I = tensor_index_t<C>;
using O = tensor_offset_t<C>;
Expand Down Expand Up @@ -111,6 +118,12 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
using I = tensor_index_t<C>;
using O = tensor_offset_t<C>;

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand Down
7 changes: 7 additions & 0 deletions include/spblas/vendor/aoclsparse/spmm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "aoclsparse.h"
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include <fmt/core.h>
Expand Down Expand Up @@ -44,6 +45,12 @@ void multiply(A&& a, X&& x, Y&& y) {
auto x_base = __detail::get_ultimate_base(x);
auto y_base = __detail::get_ultimate_base(y);

if (__detail::is_conjugated(a) || __detail::is_conjugated(x) ||
__detail::is_conjugated(y)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

aoclsparse_matrix csrA = __aoclsparse::create_matrix_handle(a_base);

aoclsparse_operation opA = __aoclsparse::get_transpose(a);
Expand Down
7 changes: 7 additions & 0 deletions include/spblas/vendor/aoclsparse/spmv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "aoclsparse.h"
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include <fmt/core.h>
Expand Down Expand Up @@ -38,6 +39,12 @@ void multiply(A&& a, X&& x, Y&& y) {
auto a_base = __detail::get_ultimate_base(a);
auto x_base = __detail::get_ultimate_base(x);

if (__detail::is_conjugated(a) || __detail::is_conjugated(x) ||
__detail::is_conjugated(y)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

aoclsparse_matrix csrA = __aoclsparse::create_matrix_handle(a_base);
aoclsparse_operation opA = __aoclsparse::get_transpose(a);

Expand Down
7 changes: 7 additions & 0 deletions include/spblas/vendor/aoclsparse/triangular_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "aoclsparse.h"
#include <cstdint>
#include <stdexcept>

#include "aocl_wrappers.hpp"
#include <fmt/core.h>
Expand Down Expand Up @@ -46,6 +47,12 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b,
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(x)) {
throw std::runtime_error(
"aoclsparse backend does not support conjugated views.");
}

using T = tensor_scalar_t<A>;
using I = tensor_index_t<A>;
using O = tensor_offset_t<A>;
Expand Down
26 changes: 26 additions & 0 deletions include/spblas/vendor/armpl/multiply_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <spblas/vendor/armpl/detail/detail.hpp>

#include <stdexcept>

#include <spblas/detail/log.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/ranges.hpp>
Expand All @@ -21,6 +23,12 @@ void multiply(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand All @@ -47,6 +55,12 @@ void multiply(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand Down Expand Up @@ -92,6 +106,12 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand Down Expand Up @@ -121,6 +141,12 @@ void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
auto c_handle = info.state_.c_handle;

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(c)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

__armpl::export_matrix_handle(info, c, c_handle);
}

Expand Down
8 changes: 8 additions & 0 deletions include/spblas/vendor/armpl/triangular_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <spblas/vendor/armpl/detail/armpl.hpp>

#include <stdexcept>

#include <spblas/detail/log.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/ranges.hpp>
Expand All @@ -26,6 +28,12 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b,
auto a_base = __detail::get_ultimate_base(a);
auto b_base = __detail::get_ultimate_base(b);

if (__detail::is_conjugated(a) || __detail::is_conjugated(b) ||
__detail::is_conjugated(x)) {
throw std::runtime_error(
"armpl backend does not support conjugated views.");
}

using T = tensor_scalar_t<A>;
using I = tensor_index_t<A>;
using O = tensor_offset_t<A>;
Expand Down
8 changes: 8 additions & 0 deletions include/spblas/vendor/cusparse/spmv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <cusparse.h>

#include <stdexcept>

#include <spblas/detail/log.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/ranges.hpp>
Expand All @@ -27,6 +29,12 @@ void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) {
auto x_base = __detail::get_ultimate_base(x);
auto a_base = __detail::get_ultimate_base(a);

if (__detail::is_conjugated(a) || __detail::is_conjugated(x) ||
__detail::is_conjugated(y)) {
throw std::runtime_error(
"cusparse backend does not support conjugated views.");
}

auto alpha_optional = __detail::get_scaling_factor(a, x);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);
tensor_scalar_t<A> beta = 0;
Expand Down
13 changes: 10 additions & 3 deletions include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <oneapi/mkl.hpp>

#include <stdexcept>

#include <spblas/detail/view_inspectors.hpp>

namespace spblas {
Expand Down Expand Up @@ -57,14 +59,19 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q,
// CSC_transpose -> CSR + nontrans
//
template <matrix M>
oneapi::mkl::transpose get_transpose(M&& m) {
oneapi::mkl::transpose get_transpose(M&& m, bool conjugate = false) {
static_assert(__detail::has_csr_base<M> || __detail::has_csc_base<M>);
if constexpr (__detail::has_base<M>) {
return get_transpose(m.base());
return get_transpose(m.base(), conjugate);
} else if constexpr (__detail::is_csr_view_v<M>) {
if (conjugate) {
throw std::runtime_error(
"oneMKL SYCL backend does not support conjugation for CSR views.");
}
return oneapi::mkl::transpose::nontrans;
} else if constexpr (__detail::is_csc_view_v<M>) {
return oneapi::mkl::transpose::trans;
return conjugate ? oneapi::mkl::transpose::conjtrans
: oneapi::mkl::transpose::trans;
}
}

Expand Down
20 changes: 18 additions & 2 deletions include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <oneapi/mkl.hpp>

#include <stdexcept>

#include <spblas/detail/log.hpp>

#include <spblas/algorithms/transposed.hpp>
Expand Down Expand Up @@ -39,6 +41,11 @@ operation_info_t
using oneapi::mkl::sparse::matmat_request;
using oneapi::mkl::sparse::matrix_view_descr;

if (__detail::is_conjugated(c)) {
throw std::runtime_error(
"oneMKL SYCL backend does not support conjugated output matrices.");
}

auto a_data = __detail::get_ultimate_base(a).values().data();
auto&& q = __mkl::get_queue(policy, a_data);

Expand Down Expand Up @@ -69,8 +76,9 @@ operation_info_t

oneapi::mkl::sparse::set_matmat_data(
descr, matrix_view_descr::general,
__mkl::get_transpose(a), // view/op for A
matrix_view_descr::general, __mkl::get_transpose(b), // view/op for B
__mkl::get_transpose(a, __detail::is_conjugated(a)), // view/op for A
matrix_view_descr::general,
__mkl::get_transpose(b, __detail::is_conjugated(b)), // view/op for B
matrix_view_descr::general); // view for C

auto ev1 = oneapi::mkl::sparse::matmat(q, a_handle, b_handle, c_handle,
Expand Down Expand Up @@ -119,6 +127,14 @@ void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a,
B&& b, C&& c) {
log_trace("");

if (__detail::is_conjugated(c)) {
throw std::runtime_error(
"oneMKL SYCL backend does not support conjugated output matrices.");
}

(void) __mkl::get_transpose(a, __detail::is_conjugated(a));
(void) __mkl::get_transpose(b, __detail::is_conjugated(b));

auto alpha_optional = __detail::get_scaling_factor(a, b);
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);

Expand Down
Loading
Loading