@@ -642,34 +642,37 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
     Û0::Vector{JNT},  X̂0::Vector{JNT}  = zeros(JNT, nU),  zeros(JNT, nX̂)
     gc::Vector{JNT},  g::Vector{JNT}   = zeros(JNT, nc),  zeros(JNT, ng)
     gi::Vector{JNT},  geq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
+    λi::Vector{JNT},  λeq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
     # -------------- inequality constraint: nonlinear oracle -----------------------------
     function gi!(gi, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
         gi .= @views g[i_g]
         return nothing
     end
-    function ℓ_gi(Z̃_λ_gi, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g, gi)
-        Z̃, λ = @views Z̃_λ_gi[begin:begin+nZ̃-1], Z̃_λ_gi[begin+nZ̃:end]
+    function ℓ_gi(Z̃, λi, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g, gi)
         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
         gi .= @views g[i_g]
-        return dot(λ, gi)
+        return dot(λi, gi)
     end
     Z̃_∇gi = fill(myNaN, nZ̃)    # NaN to force update at first call
-    Z̃_λ_gi = fill(myNaN, nZ̃ + ngi)
     ∇gi_context = (
         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
         Cache(Û0), Cache(K0), Cache(X̂0),
         Cache(gc), Cache(geq), Cache(g)
     )
     ∇gi_prep = prepare_jacobian(gi!, gi, jac, Z̃_∇gi, ∇gi_context...; strict)
-    ∇²gi_context = (
-        Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
-        Cache(Û0), Cache(K0), Cache(X̂0),
-        Cache(gc), Cache(geq), Cache(g), Cache(gi)
-    )
-    ∇²gi_prep = prepare_hessian(ℓ_gi, hess, Z̃_λ_gi, ∇²gi_context...; strict)
-    ∇gi = init_diffmat(JNT, jac, ∇gi_prep, nZ̃, ngi)
-    ∇²ℓ_gi = init_diffmat(JNT, hess, ∇²gi_prep, nZ̃ + ngi, nZ̃ + ngi)
+    ∇gi = init_diffmat(JNT, jac, ∇gi_prep, nZ̃, ngi)
+    ∇gi_structure = init_diffstructure(∇gi)
+    if !isnothing(hess)
+        ∇²gi_context = (
+            Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
+            Cache(Û0), Cache(K0), Cache(X̂0),
+            Cache(gc), Cache(geq), Cache(g), Cache(gi)
+        )
+        ∇²gi_prep = prepare_hessian(ℓ_gi, hess, Z̃_∇gi, Constant(λi), ∇²gi_context...; strict)
+        ∇²ℓ_gi = init_diffmat(JNT, hess, ∇²gi_prep, nZ̃, nZ̃)
+        ∇²gi_structure = lowertriangle_indices(init_diffstructure(∇²ℓ_gi))
+    end
     function update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
         if isdifferent(Z̃_arg, Z̃_∇gi)
             Z̃_∇gi .= Z̃_arg
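
The new Hessian code above hinges on DifferentiationInterface's `Constant` context: the multipliers `λi` are passed as a non-differentiated argument, so `prepare_hessian` and `hessian!` differentiate the scalar Lagrangian `ℓ_gi` with respect to `Z̃` only, and the resulting matrix stays `nZ̃ × nZ̃` instead of the old `(nZ̃ + ngi)`-square matrix built from the concatenated `Z̃_λ_gi` vector. A minimal standalone sketch of that pattern follows; the `AutoForwardDiff()` backend and the toy constraint function are illustrative choices, not taken from the package:

using DifferentiationInterface, LinearAlgebra
import ForwardDiff

g(x) = x .^ 2 .- 1                 # toy constraint vector, stands in for gi
ℓ(x, λ) = dot(λ, g(x))             # weighted sum, same shape as ℓ_gi above

backend = AutoForwardDiff()        # stand-in for the `hess` backend
x, λ = [1.0, 2.0], [0.5, 1.5]

# `Constant(λ)` marks λ as a non-differentiated argument: the Hessian is taken
# with respect to x only, so it is length(x) × length(x). The vector wrapped by
# `Constant` can be mutated in place between calls, which is what ∇²gi_func!
# does with λi before calling hessian!.
prep = prepare_hessian(ℓ, backend, x, Constant(λ))
H = similar(x, length(x), length(x))
hessian!(ℓ, H, prep, backend, x, Constant(λ))   # here H == 2 * Diagonal(λ)
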
@@ -682,45 +685,55 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
         return gi_arg .= gi
     end
     function ∇gi_func!(∇gi_arg, Z̃_arg)
-        update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
-        return diffmat2vec!(∇gi_arg, ∇gi)
+        update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
+        return diffmat2vec!(∇gi_arg, ∇gi, ∇gi_structure)
     end
     function ∇²gi_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
-        Z̃_λ_gi[begin:begin+nZ̃-1] .= Z̃_arg
-        Z̃_λ_gi[begin+nZ̃:end] .= λ_arg
-        hessian!(ℓ_gi, ∇²ℓ_gi, ∇²gi_prep, hess, Z̃_λ_gi, ∇²gi_context)
-        return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_gi)
+        Z̃_∇gi .= Z̃_arg
+        λi .= λ_arg
+        hessian!(ℓ_gi, ∇²ℓ_gi, ∇²gi_prep, hess, Z̃_∇gi, Constant(λi), ∇²gi_context...)
+        return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_gi, ∇²gi_structure)
     end
     gi_min = fill(-myInf, ngi)
     gi_max = zeros(JNT, ngi)
-    ∇gi_structure = init_diffstructure(∇gi)
-    ∇²gi_structure = init_diffstructure(∇²ℓ_gi)
-    display(∇²ℓ_gi)
     g_oracle = MOI.VectorNonlinearOracle(;
         dimension = nZ̃,
         l = gi_min,
         u = gi_max,
         eval_f = gi_func!,
         jacobian_structure = ∇gi_structure,
         eval_jacobian = ∇gi_func!,
-        hessian_lagrangian_structure = ∇²gi_structure,
-        eval_hessian_lagrangian = ∇²gi_func!
+        hessian_lagrangian_structure = isnothing(hess) ? Tuple{Int,Int}[] : ∇²gi_structure,
+        eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²gi_func!
     )
-    # TODO: verify if I must fill only upper/lower triangular part?
-    # TODO: add Hessian for 1. Jfunc and 2. geq
     # ------------- equality constraints: nonlinear oracle ------------------------------
     function geq!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
         return nothing
     end
+    function ℓ_geq(Z̃, λeq, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
+        update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
+        return dot(λeq, geq)
+    end
     Z̃_∇geq = fill(myNaN, nZ̃)    # NaN to force update at first call
     ∇geq_context = (
         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
         Cache(Û0), Cache(K0), Cache(X̂0),
         Cache(gc), Cache(g)
     )
     ∇geq_prep = prepare_jacobian(geq!, geq, jac, Z̃_∇geq, ∇geq_context...; strict)
-    ∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
+    ∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
+    ∇geq_structure = init_diffstructure(∇geq)
+    if !isnothing(hess)
+        ∇²geq_context = (
+            Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
+            Cache(Û0), Cache(K0), Cache(X̂0),
+            Cache(gc), Cache(geq), Cache(g)
+        )
+        ∇²geq_prep = prepare_hessian(ℓ_geq, hess, Z̃_∇geq, Constant(λeq), ∇²geq_context...; strict)
+        ∇²ℓ_geq = init_diffmat(JNT, hess, ∇²geq_prep, nZ̃, nZ̃)
+        ∇²geq_structure = lowertriangle_indices(init_diffstructure(∇²ℓ_geq))
+    end
     function update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
         if isdifferent(Z̃_arg, Z̃_∇geq)
             Z̃_∇geq .= Z̃_arg
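
Note that both oracles hand the solver only the lower triangle of the Lagrangian Hessian: that is what `lowertriangle_indices` extracts from the sparsity pattern returned by `init_diffstructure`, and what `fill_lowertriangle!` writes in the multivariate operator further down. The helper itself is defined elsewhere in the package; a plausible stand-in, assuming the structure is a vector of `(row, col)` tuples, could look like this:

# Hypothetical stand-in for the package's `lowertriangle_indices` helper:
# keep only the coordinates on or below the diagonal, since the Hessian of the
# Lagrangian is symmetric and only its lower triangle is reported to the solver.
lowertriangle_indices(structure::Vector{Tuple{Int,Int}}) =
    filter(((i, j),) -> i ≥ j, structure)
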
@@ -734,53 +747,109 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
     end
     function ∇geq_func!(∇geq_arg, Z̃_arg)
         update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
-        return diffmat2vec!(∇geq_arg, ∇geq)
+        return diffmat2vec!(∇geq_arg, ∇geq, ∇geq_structure)
+    end
+    function ∇²geq_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
+        Z̃_∇geq .= Z̃_arg
+        λeq .= λ_arg
+        hessian!(ℓ_geq, ∇²ℓ_geq, ∇²geq_prep, hess, Z̃_∇geq, Constant(λeq), ∇²geq_context...)
+        return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_geq, ∇²geq_structure)
     end
     geq_min = geq_max = zeros(JNT, neq)
-    ∇geq_structure = init_diffstructure(∇geq)
     geq_oracle = MOI.VectorNonlinearOracle(;
         dimension = nZ̃,
         l = geq_min,
         u = geq_max,
         eval_f = geq_func!,
         jacobian_structure = ∇geq_structure,
-        eval_jacobian = ∇geq_func!
+        eval_jacobian = ∇geq_func!,
+        hessian_lagrangian_structure = isnothing(hess) ? Tuple{Int,Int}[] : ∇²geq_structure,
+        eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²geq_func!
     )
     # ------------- objective function: splatting syntax ---------------------------------
     function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
         return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
     end
-    Z̃_∇J = fill(myNaN, nZ̃)    # NaN to force update at first call
-    ∇J_context = (
+    Z̃_J = fill(myNaN, nZ̃)    # NaN to force update at first call
+    J_context = (
         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
         Cache(Û0), Cache(K0), Cache(X̂0),
         Cache(gc), Cache(g), Cache(geq),
     )
-    ∇J_prep = prepare_gradient(J!, grad, Z̃_∇J, ∇J_context...; strict)
-    ∇J = Vector{JNT}(undef, nZ̃)
-    function update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
-        if isdifferent(Z̃_arg, Z̃_∇J)
-            Z̃_∇J .= Z̃_arg
-            J[], _ = value_and_gradient!(J!, ∇J, ∇J_prep, grad, Z̃_∇J, ∇J_context...)
+    ∇J_prep = prepare_gradient(J!, grad, Z̃_J, J_context...; strict)
+    ∇J = Vector{JNT}(undef, nZ̃)
+    if !isnothing(hess)
+        ∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_context...; strict)
+        ∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
+    end
+    update_objective! = if !isnothing(hess)
+        function (J, ∇J, ∇²J, Z̃_J, Z̃_arg)
+            if isdifferent(Z̃_arg, Z̃_J)
+                Z̃_J .= Z̃_arg
+                J[], _ = value_gradient_and_hessian!(J!, ∇J, ∇²J, ∇²J_prep, hess, Z̃_J, J_context...)
+            end
+        end
+    else
+        function (J, ∇J, Z̃_∇J, Z̃_arg)
+            if isdifferent(Z̃_arg, Z̃_∇J)
+                Z̃_∇J .= Z̃_arg
+                J[], _ = value_and_gradient!(J!, ∇J, ∇J_prep, grad, Z̃_∇J, J_context...)
+            end
+        end
+    end
+    J_func = if !isnothing(hess)
+        function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
+            update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
+            return J[]::T
+        end
+    else
+        function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
+            update_objective!(J, ∇J, Z̃_J, Z̃_arg)
+            return J[]::T
         end
-    end
-    function J_func(Z̃_arg::Vararg{T, N}) where {N, T<:Real}
-        update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
-        return J[]::T
     end
     ∇J_func! = if nZ̃ == 1    # univariate syntax (see JuMP.@operator doc):
+        if !isnothing(hess)
+            function (Z̃_arg)
+                update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
+                return ∇J[]
+            end
+        else
+            function (Z̃_arg)
+                update_objective!(J, ∇J, Z̃_J, Z̃_arg)
+                return ∇J[]
+            end
+        end
+    else    # multivariate syntax (see JuMP.@operator doc):
+        if !isnothing(hess)
+            function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
+                update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
+                return ∇J_arg .= ∇J
+            end
+        else
+            function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
+                update_objective!(J, ∇J, Z̃_J, Z̃_arg)
+                return ∇J_arg .= ∇J
+            end
+        end
+    end
+    ∇²J_func! = if nZ̃ == 1    # univariate syntax (see JuMP.@operator doc):
         function (Z̃_arg)
-            update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
-            return ∇J[]
+            update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
+            return ∇²J[]
         end
     else    # multivariate syntax (see JuMP.@operator doc):
-        function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
-            update_objective!(J, ∇J, Z̃_∇J, Z̃_arg)
-            return ∇J_arg .= ∇J
+        function (∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
+            update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
+            return fill_lowertriangle!(∇²J_arg, ∇²J)
         end
     end
-    J_op = JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, name=:J_op)
+    J_op = if !isnothing(hess)
+        JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, ∇²J_func!, name=:J_op)
+    else
+        JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, name=:J_op)
+    end
     return g_oracle, geq_oracle, J_op
 end

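
For context, the three returned objects are consumed by the caller, which is outside this diff. Roughly, the wiring looks like the sketch below; the variable name `Z̃var` and the lookup `optim[:Z̃var]` are illustrative assumptions, not code from the commit:

# Hypothetical call site (not part of this commit): the oracles become vector
# constraints on the decision variables, and the registered operator is used
# in the objective with the splatting syntax mentioned above.
Z̃var = optim[:Z̃var]                       # the nZ̃ decision variables
@constraint(optim, Z̃var in g_oracle)       # nonlinear inequalities, g ≤ 0
@constraint(optim, Z̃var in geq_oracle)     # nonlinear equalities, geq = 0
@objective(optim, Min, J_op(Z̃var...))      # objective through the registered operator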