From 91b010dfc7d0d3bf8e7afc4c1706c93fc4e0195d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 15 Jan 2026 14:24:24 +0100 Subject: [PATCH 1/4] consistent checksquare usage --- src/implementations/eig.jl | 18 ++++++++---------- src/implementations/eigh.jl | 3 +-- src/implementations/gen_eig.jl | 24 +++++++++++++----------- src/implementations/schur.jl | 8 +++----- 4 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 9b785898..01e4eccd 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -9,8 +9,7 @@ copy_input(::Union{typeof(eig_trunc), typeof(eig_trunc_no_error)}, A) = copy_inp copy_input(::typeof(eig_full), A::Diagonal) = copy(A) function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) D, V = DV @assert D isa Diagonal && V isa AbstractMatrix @check_size(D, (m, m)) @@ -20,17 +19,16 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgor return nothing end function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) @assert D isa AbstractVector - @check_size(D, (n,)) + @check_size(D, (m,)) @check_scalar(D, A, complex) return nothing end function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm) - m, n = size(A) - ((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected")) + m = LinearAlgebra.checksquare(A) + isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected")) D, V = DV @assert D isa Diagonal && V isa AbstractMatrix @check_size(D, (m, m)) @@ -40,10 +38,10 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgor return nothing end function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm) - m, n = size(A) - ((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected")) + m = LinearAlgebra.checksquare(A) + isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected")) @assert D isa AbstractVector - @check_size(D, (n,)) + @check_size(D, (m,)) @check_scalar(D, A, complex) return nothing end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 19a190b1..d17e25e8 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -11,8 +11,7 @@ copy_input(::typeof(eigh_full), A::Diagonal) = copy(A) check_hermitian(A, ::AbstractAlgorithm) = check_hermitian(A) check_hermitian(A, alg::Algorithm) = check_hermitian(A; atol = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A))) function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = 0) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + LinearAlgebra.checksquare(A) ishermitian(A; atol, rtol) || throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix.")) return nothing diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index f4a6a7f5..ecb26928 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -5,13 +5,17 @@ function copy_input(::typeof(gen_eig_full), A::AbstractMatrix, B::AbstractMatrix end copy_input(::typeof(gen_eig_vals), A, B) = copy_input(gen_eig_full, A, B) +@noinline function _check_gen_eig_size(A, B) + m = size(A, 1) + n = size(B, 1) + m == n || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $m and $n")) + return m +end + function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm) - ma, na = size(A) - mb, nb = size(B) - ma == na || throw(DimensionMismatch("square input matrix A expected")) - mb == nb || throw(DimensionMismatch("square input matrix B expected")) - ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match")) - na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match")) + ma = LinearAlgebra.checksquare(A) + mb = LinearAlgebra.checksquare(B) + ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb")) W, V = WV @assert W isa Diagonal && V isa AbstractMatrix @check_size(W, (ma, ma)) @@ -23,11 +27,9 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr return nothing end function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm) - ma, na = size(A) - mb, nb = size(B) - ma == na || throw(DimensionMismatch("square input matrix A expected")) - mb == nb || throw(DimensionMismatch("square input matrix B expected")) - ma == mb || throw(DimensionMismatch("dimension of input matrices expected to match")) + ma = LinearAlgebra.checksquare(A) + mb = LinearAlgebra.checksquare(B) + ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb")) @assert W isa AbstractVector @check_size(W, (na,)) @check_scalar(W, A, complex) diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 3cad9e0f..164b0392 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -5,8 +5,7 @@ copy_input(::typeof(schur_vals), A) = copy_input(eig_vals, A) # check input function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) T, Z, vals = TZv @assert T isa AbstractMatrix && Z isa AbstractMatrix && vals isa AbstractVector @check_size(T, (m, m)) @@ -18,10 +17,9 @@ function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAl return nothing end function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) + m = LinearAlgebra.checksquare(A) @assert vals isa AbstractVector - @check_size(vals, (n,)) + @check_size(vals, (m,)) @check_scalar(vals, A, complex) return nothing end From f7be69342cfc18f77039f3710000bc68400ff995 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 15 Jan 2026 15:05:23 +0100 Subject: [PATCH 2/4] remove unused function --- src/implementations/gen_eig.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index ecb26928..00cd889f 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -5,13 +5,6 @@ function copy_input(::typeof(gen_eig_full), A::AbstractMatrix, B::AbstractMatrix end copy_input(::typeof(gen_eig_vals), A, B) = copy_input(gen_eig_full, A, B) -@noinline function _check_gen_eig_size(A, B) - m = size(A, 1) - n = size(B, 1) - m == n || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $m and $n")) - return m -end - function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm) ma = LinearAlgebra.checksquare(A) mb = LinearAlgebra.checksquare(B) From 4a987395b35810d37ba01f4aca20fa5b0eb7a592 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 15 Jan 2026 15:06:52 +0100 Subject: [PATCH 3/4] fix typo --- src/implementations/gen_eig.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index 00cd889f..043da2d9 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -24,7 +24,7 @@ function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatr mb = LinearAlgebra.checksquare(B) ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb")) @assert W isa AbstractVector - @check_size(W, (na,)) + @check_size(W, (ma,)) @check_scalar(W, A, complex) @check_scalar(W, B, complex) return nothing From d4138585f44f5fd7a2edf8431ee751eeacac97d9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 15 Jan 2026 18:06:00 +0100 Subject: [PATCH 4/4] fix one more typo --- src/implementations/schur.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 164b0392..a193c912 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -12,7 +12,7 @@ function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAl @check_scalar(T, A) @check_size(Z, (m, m)) @check_scalar(Z, A) - @check_size(vals, (n,)) + @check_size(vals, (m,)) @check_scalar(vals, A, complex) return nothing end