diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 9b785898..01e4eccd 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -9,8 +9,7 @@ copy_input(::Union{typeof(eig_trunc), typeof(eig_trunc_no_error)}, A) = copy_inp copy_input(::typeof(eig_full), A::Diagonal) = copy(A) function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) D, V = DV @assert D isa Diagonal && V isa AbstractMatrix @check_size(D, (m, m)) @@ -20,17 +19,16 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgor return nothing end function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) @assert D isa AbstractVector - @check_size(D, (n,)) + @check_size(D, (m,)) @check_scalar(D, A, complex) return nothing end function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm) - m, n = size(A) - ((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected")) + m = LinearAlgebra.checksquare(A) + isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected")) D, V = DV @assert D isa Diagonal && V isa AbstractMatrix @check_size(D, (m, m)) @@ -40,10 +38,10 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgor return nothing end function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm) - m, n = size(A) - ((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected")) + m = LinearAlgebra.checksquare(A) + isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected")) @assert D isa AbstractVector - @check_size(D, (n,)) + @check_size(D, (m,)) @check_scalar(D, A, complex) return nothing end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 19a190b1..d17e25e8 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -11,8 +11,7 @@ copy_input(::typeof(eigh_full), A::Diagonal) = copy(A) check_hermitian(A, ::AbstractAlgorithm) = check_hermitian(A) check_hermitian(A, alg::Algorithm) = check_hermitian(A; atol = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A))) function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = 0) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + LinearAlgebra.checksquare(A) ishermitian(A; atol, rtol) || throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix.")) return nothing diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index f4a6a7f5..043da2d9 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -6,12 +6,9 @@ end copy_input(::typeof(gen_eig_vals), A, B) = copy_input(gen_eig_full, A, B) function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm) - ma, na = size(A) - mb, nb = size(B) - ma == na || throw(DimensionMismatch("square input matrix A expected")) - mb == nb || throw(DimensionMismatch("square input matrix B expected")) - ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match")) - na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match")) + ma = LinearAlgebra.checksquare(A) + mb = LinearAlgebra.checksquare(B) + ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb")) W, V = WV @assert W isa Diagonal && V isa AbstractMatrix @check_size(W, (ma, ma)) @@ -23,13 +20,11 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr return nothing end function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm) - ma, na = size(A) - mb, nb = size(B) - ma == na || throw(DimensionMismatch("square input matrix A expected")) - mb == nb || throw(DimensionMismatch("square input matrix B expected")) - ma == mb || throw(DimensionMismatch("dimension of input matrices expected to match")) + ma = LinearAlgebra.checksquare(A) + mb = LinearAlgebra.checksquare(B) + ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb")) @assert W isa AbstractVector - @check_size(W, (na,)) + @check_size(W, (ma,)) @check_scalar(W, A, complex) @check_scalar(W, B, complex) return nothing diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 3cad9e0f..a193c912 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -5,23 +5,21 @@ copy_input(::typeof(schur_vals), A) = copy_input(eig_vals, A) # check input function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) T, Z, vals = TZv @assert T isa AbstractMatrix && Z isa AbstractMatrix && vals isa AbstractVector @check_size(T, (m, m)) @check_scalar(T, A) @check_size(Z, (m, m)) @check_scalar(Z, A) - @check_size(vals, (n,)) + @check_size(vals, (m,)) @check_scalar(vals, A, complex) return nothing end function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) @assert vals isa AbstractVector - @check_size(vals, (n,)) + @check_size(vals, (m,)) @check_scalar(vals, A, complex) return nothing end