Skip to content

Commit fd52b78

Browse files
committed
added: hessian now works well for NonLinMPC 🍾
1 parent bbf48dd commit fd52b78

File tree

2 files changed

+137
-50
lines changed

2 files changed

+137
-50
lines changed

src/controller/nonlinmpc.jl

Lines changed: 116 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -642,34 +642,37 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
642642
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
643643
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
644644
gi::Vector{JNT}, geq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
645+
λi::Vector{JNT}, λeq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
645646
# -------------- inequality constraint: nonlinear oracle -----------------------------
646647
function gi!(gi, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
647648
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
648649
gi .= @views g[i_g]
649650
return nothing
650651
end
651-
function ℓ_gi(Z̃_λ_gi, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g, gi)
652-
Z̃, λ = @views Z̃_λ_gi[begin:begin+nZ̃-1], Z̃_λ_gi[begin+nZ̃:end]
652+
function ℓ_gi(Z̃, λi, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g, gi)
653653
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
654654
gi .= @views g[i_g]
655-
return dot(λ, gi)
655+
return dot(λi, gi)
656656
end
657657
Z̃_∇gi = fill(myNaN, nZ̃) # NaN to force update at first call
658-
Z̃_λ_gi = fill(myNaN, nZ̃ + ngi)
659658
∇gi_context = (
660659
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
661660
Cache(Û0), Cache(K0), Cache(X̂0),
662661
Cache(gc), Cache(geq), Cache(g)
663662
)
664663
∇gi_prep = prepare_jacobian(gi!, gi, jac, Z̃_∇gi, ∇gi_context...; strict)
665-
∇²gi_context = (
666-
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
667-
Cache(Û0), Cache(K0), Cache(X̂0),
668-
Cache(gc), Cache(geq), Cache(g), Cache(gi)
669-
)
670-
∇²gi_prep = prepare_hessian(ℓ_gi, hess, Z̃_λ_gi, ∇²gi_context...; strict)
671-
∇gi = init_diffmat(JNT, jac, ∇gi_prep, nZ̃, ngi)
672-
∇²ℓ_gi = init_diffmat(JNT, hess, ∇²gi_prep, nZ̃ + ngi, nZ̃ + ngi)
664+
∇gi = init_diffmat(JNT, jac, ∇gi_prep, nZ̃, ngi)
665+
∇gi_structure = init_diffstructure(∇gi)
666+
if !isnothing(hess)
667+
∇²gi_context = (
668+
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
669+
Cache(Û0), Cache(K0), Cache(X̂0),
670+
Cache(gc), Cache(geq), Cache(g), Cache(gi)
671+
)
672+
∇²gi_prep = prepare_hessian(ℓ_gi, hess, Z̃_∇gi, Constant(λi), ∇²gi_context...; strict)
673+
∇²ℓ_gi = init_diffmat(JNT, hess, ∇²gi_prep, nZ̃, nZ̃)
674+
∇²gi_structure = lowertriangle_indices(init_diffstructure(∇²ℓ_gi))
675+
end
673676
function update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
674677
if isdifferent(Z̃_arg, Z̃_∇gi)
675678
Z̃_∇gi .= Z̃_arg
@@ -682,45 +685,55 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
682685
return gi_arg .= gi
683686
end
684687
function ∇gi_func!(∇gi_arg, Z̃_arg)
685-
update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
686-
return diffmat2vec!(∇gi_arg, ∇gi)
688+
update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
689+
return diffmat2vec!(∇gi_arg, ∇gi, ∇gi_structure)
687690
end
688691
function ∇²gi_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
689-
Z̃_λ_gi[1:begin:begin+nZ̃-1] .= Z̃_arg
690-
Z̃_λ_gi[[begin+nZ̃:end]] .= λ_arg
691-
hessian!(ℓ_gi, ∇²ℓ_gi, ∇²gi_prep, hess, Z̃_λ_gi, ∇²gi_context)
692-
return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_gi)
692+
Z̃_∇gi .= Z̃_arg
693+
λi .= λ_arg
694+
hessian!(ℓ_gi, ∇²ℓ_gi, ∇²gi_prep, hess, Z̃_∇gi, Constant(λi), ∇²gi_context...)
695+
return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_gi, ∇²gi_structure)
693696
end
694697
gi_min = fill(-myInf, ngi)
695698
gi_max = zeros(JNT, ngi)
696-
∇gi_structure = init_diffstructure(∇gi)
697-
∇²gi_structure = init_diffstructure(∇²ℓ_gi)
698-
display(∇²ℓ_gi)
699699
g_oracle = MOI.VectorNonlinearOracle(;
700700
dimension = nZ̃,
701701
l = gi_min,
702702
u = gi_max,
703703
eval_f = gi_func!,
704704
jacobian_structure = ∇gi_structure,
705705
eval_jacobian = ∇gi_func!,
706-
hessian_lagrangian_structure = ∇²gi_structure,
707-
eval_hessian_lagrangian = ∇²gi_func!
706+
hessian_lagrangian_structure = isnothing(hess) ? Tuple{Int,Int}[] : ∇²gi_structure,
707+
eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²gi_func!
708708
)
709-
#TODO: verify if I must fill only upper/lower triangular part ?
710-
#TODO: add Hessian for 1. Jfunc and 2. geq
711709
# ------------- equality constraints : nonlinear oracle ------------------------------
712710
function geq!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
713711
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
714712
return nothing
715713
end
714+
function ℓ_geq(Z̃, λeq, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
715+
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
716+
return dot(λeq, geq)
717+
end
716718
Z̃_∇geq = fill(myNaN, nZ̃) # NaN to force update at first call
717719
∇geq_context = (
718720
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
719721
Cache(Û0), Cache(K0), Cache(X̂0),
720722
Cache(gc), Cache(g)
721723
)
722724
∇geq_prep = prepare_jacobian(geq!, geq, jac, Z̃_∇geq, ∇geq_context...; strict)
723-
∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
725+
∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
726+
∇geq_structure = init_diffstructure(∇geq)
727+
if !isnothing(hess)
728+
∇²geq_context = (
729+
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
730+
Cache(Û0), Cache(K0), Cache(X̂0),
731+
Cache(gc), Cache(geq), Cache(g)
732+
)
733+
∇²geq_prep = prepare_hessian(ℓ_geq, hess, Z̃_∇geq, Constant(λeq), ∇²geq_context...; strict)
734+
∇²ℓ_geq = init_diffmat(JNT, hess, ∇²geq_prep, nZ̃, nZ̃)
735+
∇²geq_structure = lowertriangle_indices(init_diffstructure(∇²ℓ_geq))
736+
end
724737
function update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
725738
if isdifferent(Z̃_arg, Z̃_∇geq)
726739
Z̃_∇geq .= Z̃_arg
@@ -734,53 +747,109 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
734747
end
735748
function ∇geq_func!(∇geq_arg, Z̃_arg)
736749
update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
737-
return diffmat2vec!(∇geq_arg, ∇geq)
750+
return diffmat2vec!(∇geq_arg, ∇geq, ∇geq_structure)
751+
end
752+
function ∇²geq_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
753+
Z̃_∇geq .= Z̃_arg
754+
λeq .= λ_arg
755+
hessian!(ℓ_geq, ∇²ℓ_geq, ∇²geq_prep, hess, Z̃_∇geq, Constant(λeq), ∇²geq_context...)
756+
return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_geq, ∇²geq_structure)
738757
end
739758
geq_min = geq_max = zeros(JNT, neq)
740-
∇geq_structure = init_diffstructure(∇geq)
741759
geq_oracle = MOI.VectorNonlinearOracle(;
742760
dimension = nZ̃,
743761
l = geq_min,
744762
u = geq_max,
745763
eval_f = geq_func!,
746764
jacobian_structure = ∇geq_structure,
747-
eval_jacobian = ∇geq_func!
765+
eval_jacobian = ∇geq_func!,
766+
hessian_lagrangian_structure = isnothing(hess) ? Tuple{Int,Int}[] : ∇²geq_structure,
767+
eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²geq_func!
748768
)
749769
# ------------- objective function: splatting syntax ---------------------------------
750770
function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
751771
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
752772
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
753773
end
754-
Z̃_∇J = fill(myNaN, nZ̃) # NaN to force update at first call
755-
J_context = (
774+
Z̃_J = fill(myNaN, nZ̃) # NaN to force update at first call
775+
J_context = (
756776
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
757777
Cache(Û0), Cache(K0), Cache(X̂0),
758778
Cache(gc), Cache(g), Cache(geq),
759779
)
760-
∇J_prep = prepare_gradient(J!, grad, Z̃_∇J, ∇J_context...; strict)
761-
∇J = Vector{JNT}(undef, nZ̃)
762-
function update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
763-
if isdifferent(Z̃_arg, Z̃_∇J)
764-
Z̃_∇J .= Z̃_arg
765-
J[], _ = value_and_gradient!(J!, ∇J, ∇J_prep, grad, Z̃_∇J, ∇J_context...)
780+
∇J_prep = prepare_gradient(J!, grad, Z̃_J, J_context...; strict)
781+
∇J = Vector{JNT}(undef, nZ̃)
782+
if !isnothing(hess)
783+
∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_context...; strict)
784+
∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
785+
end
786+
update_objective! = if !isnothing(hess)
787+
function (J, ∇J, ∇²J, Z̃_J, Z̃_arg)
788+
if isdifferent(Z̃_arg, Z̃_J)
789+
Z̃_J .= Z̃_arg
790+
J[], _ = value_gradient_and_hessian!(J!, ∇J, ∇²J, hess, Z̃_J, J_context...)
791+
end
792+
end
793+
else
794+
update_objective! = function (J, ∇J, Z̃_∇J, Z̃_arg)
795+
if isdifferent(Z̃_arg, Z̃_∇J)
796+
Z̃_∇J .= Z̃_arg
797+
J[], _ = value_and_gradient!(J!, ∇J, ∇J_prep, grad, Z̃_∇J, J_context...)
798+
end
799+
end
800+
end
801+
J_func = if !isnothing(hess)
802+
function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
803+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
804+
return J[]::T
805+
end
806+
else
807+
function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
808+
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
809+
return J[]::T
766810
end
767-
end
768-
function J_func(Z̃_arg::Vararg{T, N}) where {N, T<:Real}
769-
update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
770-
return J[]::T
771811
end
772812
∇J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
813+
if !isnothing(hess)
814+
function (Z̃_arg)
815+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
816+
return ∇J[]
817+
end
818+
else
819+
function (Z̃_arg)
820+
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
821+
return ∇J[]
822+
end
823+
end
824+
else # multivariate syntax (see JuMP.@operator doc):
825+
if !isnothing(hess)
826+
function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
827+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
828+
return ∇J_arg .= ∇J
829+
end
830+
else
831+
function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
832+
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
833+
return ∇J_arg .= ∇J
834+
end
835+
end
836+
end
837+
∇²J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
773838
function (Z̃_arg)
774-
update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
775-
return ∇J[]
839+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
840+
return ∇²J[]
776841
end
777842
else # multivariate syntax (see JuMP.@operator doc):
778-
function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
779-
update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
780-
return ∇J_arg .= ∇J
843+
function (∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
844+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
845+
return fill_lowertriangle!(∇²J_arg, ∇²J)
781846
end
782847
end
783-
J_op = JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, name=:J_op)
848+
J_op = if !isnothing(hess)
849+
JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, ∇²J_func!, name=:J_op)
850+
else
851+
JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, name=:J_op)
852+
end
784853
return g_oracle, geq_oracle, J_op
785854
end
786855

src/general.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,27 @@ function init_diffstructure(A::AbstractSparseMatrix)
7171
end
7272
init_diffstructure(A::AbstractMatrix)= Tuple.(CartesianIndices(A))[:]
7373

74-
"Store the differentiation matrix `A` in the vector `v` as required by `JuMP.jl.`"
75-
diffmat2vec!(v::AbstractVector, A::AbstractSparseMatrix) = v .= nonzeros(A)
76-
diffmat2vec!(v::AbstractVector, A::AbstractMatrix) = v[:] = A
74+
"Get the lower-triangular indices from the differentiation matrix structure."
75+
function lowertriangle_indices(diffmat_struct::Vector{Tuple{Int, Int}})
76+
return [(i,j) for (i,j) in diffmat_struct if i ≥ j]
77+
end
78+
79+
"Fill the lower triangular part of A in-place with the corresponding part in B."
80+
function fill_lowertriangle!(A::AbstractMatrix, B::AbstractMatrix)
81+
for j in axes(A, 2), i in axes(A, 1)
82+
(i ≥ j) && (A[i, j] = B[i, j])
83+
end
84+
return A
85+
end
86+
87+
"Store the diff. matrix `A` in the vector `v` with list of nonzero indices `i_vec`"
88+
function diffmat2vec!(v::AbstractVector, A::AbstractMatrix, i_vec::Vector{Tuple{Int, Int}})
89+
for i in eachindex(v)
90+
i_A, j_A = i_vec[i]
91+
v[i] = A[i_A, j_A]
92+
end
93+
return v
94+
end
7795

7896
backend_str(backend::AbstractADType) = string(nameof(typeof(backend)))
7997
function backend_str(backend::AutoSparse)

0 commit comments

Comments
 (0)