Skip to content

Commit 21189c7

Browse files
committed
reduce collect calls for NonLinMPC
1 parent ccc64e9 commit 21189c7

File tree

1 file changed

+36
-34
lines changed

1 file changed

+36
-34
lines changed

src/controller/nonlinmpc.jl

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -74,40 +74,44 @@ struct NonLinMPC{S<:StateEstimator, JEFunc<:Function} <: PredictiveController
7474
b = con.b[con.i_b]
7575
@constraint(optim, linconstraint, A*ΔŨ .≤ b)
7676

77-
last_ΔŨ, last_C, last_Ŷ = nothing, nothing, nothing
78-
function Jfunc(ΔŨ::Float64...)
79-
if ΔŨ !== last_ΔŨ
80-
last_Ŷ = predict(mpc, model, collect(ΔŨ))
81-
last_C = con_nonlinprog(mpc, model, last_Ŷ, ΔŨ)
82-
last_ΔŨ = ΔŨ
77+
last_ΔŨtup, last_C, last_Ŷ = nothing, nothing, nothing
78+
function Jfunc(ΔŨtup::Float64...)
79+
ΔŨvec = collect(ΔŨtup)
80+
if ΔŨtup !== last_ΔŨtup
81+
last_Ŷ = predict(mpc, model, ΔŨvec)
82+
last_C = con_nonlinprog(mpc, model, last_Ŷ, ΔŨvec)
83+
last_ΔŨtup = ΔŨtup
8384
end
84-
return obj_nonlinprog(mpc, model, last_Ŷ, ΔŨ)
85+
return obj_nonlinprog(mpc, model, last_Ŷ, ΔŨvec)
8586
end
86-
last_dΔŨ, last_dC, last_dŶ = nothing, nothing, nothing
87-
function Jfunc(dΔŨ::T...) where {T<:Real}
88-
if dΔŨ !== last_dΔŨ
89-
last_dŶ = predict(mpc, model, collect(dΔŨ))
90-
last_dC = con_nonlinprog(mpc, model, last_dŶ, dΔŨ)
91-
last_dΔŨ = dΔŨ
87+
last_dΔŨtup, last_dC, last_dŶ = nothing, nothing, nothing
88+
function Jfunc(dΔŨtup::T...) where {T<:Real}
89+
dΔŨvec = collect(dΔŨtup)
90+
if dΔŨtup !== last_dΔŨtup
91+
last_dŶ = predict(mpc, model, dΔŨvec)
92+
last_dC = con_nonlinprog(mpc, model, last_dŶ, dΔŨvec)
93+
last_dΔŨtup = dΔŨtup
9294
end
93-
return obj_nonlinprog(mpc, model, last_dŶ, dΔŨ)
95+
return obj_nonlinprog(mpc, model, last_dŶ, dΔŨvec)
9496
end
9597
register(optim, :Jfunc, nvar, Jfunc, autodiff=true)
9698
@NLobjective(optim, Min, Jfunc(ΔŨ...))
9799
ncon = length(mpc.con.Ŷmin) + length(mpc.con.Ŷmax)
98-
function con_nonlinprog_i(i, ΔŨ::NTuple{N, Float64}) where {N}
99-
if ΔŨ !== last_ΔŨ
100-
last_Ŷ = predict(mpc, model, collect(ΔŨ))
101-
last_C = con_nonlinprog(mpc, model, last_Ŷ, ΔŨ)
102-
last_ΔŨ = ΔŨ
100+
function con_nonlinprog_i(i, ΔŨtup::NTuple{N, Float64}) where {N}
101+
if ΔŨtup !== last_ΔŨtup
102+
ΔŨvec = collect(ΔŨtup)
103+
last_Ŷ = predict(mpc, model, ΔŨvec)
104+
last_C = con_nonlinprog(mpc, model, last_Ŷ, ΔŨvec)
105+
last_ΔŨtup = ΔŨtup
103106
end
104107
return last_C[i]
105108
end
106-
function con_nonlinprog_i(i, dΔŨ::NTuple{N, T}) where {N, T<:Real}
107-
if dΔŨ !== last_dΔŨ
108-
last_dŶ = predict(mpc, model, collect(dΔŨ))
109-
last_dC = con_nonlinprog(mpc, model, last_dŶ, dΔŨ)
110-
last_dΔŨ = dΔŨ
109+
function con_nonlinprog_i(i, dΔŨtup::NTuple{N, T}) where {N, T<:Real}
110+
if dΔŨtup !== last_dΔŨtup
111+
dΔŨvec = collect(dΔŨtup)
112+
last_dŶ = predict(mpc, model, dΔŨvec)
113+
last_dC = con_nonlinprog(mpc, model, last_dŶ, dΔŨvec)
114+
last_dΔŨtup = dΔŨtup
111115
end
112116
return last_dC[i]
113117
end
@@ -266,12 +270,11 @@ init_objective!(mpc::NonLinMPC, _ ) = nothing
266270

267271

268272
"""
269-
obj_nonlinprog(mpc::NonLinMPC, model::LinModel, ΔŨ::NTuple{N, T}) where {N, T}
273+
obj_nonlinprog(mpc::NonLinMPC, model::LinModel, ΔŨ::Vector{Real})
270274
271275
Objective function for [`NonLinMPC`] when `model` is a [`LinModel`](@ref).
272276
"""
273-
function obj_nonlinprog(mpc::NonLinMPC, model::LinModel, Ŷ, ΔŨ::NTuple{N, T}) where {N, T}
274-
ΔŨ = collect(ΔŨ) # convert NTuple to Vector
277+
function obj_nonlinprog(mpc::NonLinMPC, model::LinModel, Ŷ, ΔŨ::Vector{T}) where {T<:Real}
275278
Jqp = obj_quadprog(ΔŨ, mpc.P̃, mpc.q̃)
276279
U = mpc.S̃_Hp*ΔŨ + mpc.T_Hp*(mpc.estim.lastu0 + model.uop)
277280
UE = [U; U[(end - model.nu + 1):end]]
@@ -281,12 +284,11 @@ function obj_nonlinprog(mpc::NonLinMPC, model::LinModel, Ŷ, ΔŨ::NTuple{N, T
281284
end
282285

283286
"""
284-
obj_nonlinprog(mpc::NonLinMPC, model::SimModel, ΔŨ::NTuple{N, T}) where {N, T}
287+
obj_nonlinprog(mpc::NonLinMPC, model::SimModel, ΔŨ::Vector{Real})
285288
286289
Objective function for [`NonLinMPC`] when `model` is not a [`LinModel`](@ref).
287290
"""
288-
function obj_nonlinprog(mpc::NonLinMPC, model::SimModel, Ŷ, ΔŨ::NTuple{N, T}) where {N, T}
289-
ΔŨ = collect(ΔŨ) # convert NTuple to Vector
291+
function obj_nonlinprog(mpc::NonLinMPC, model::SimModel, Ŷ, ΔŨ::Vector{T}) where {T<:Real}
290292
U0 = mpc.S̃_Hp*ΔŨ + mpc.T_Hp*(mpc.estim.lastu0)
291293
# --- output setpoint tracking term ---
292294
êy = mpc.R̂y -
@@ -312,19 +314,19 @@ end
312314

313315

314316
"""
315-
con_nonlinprog(mpc::NonLinMPC, ::LinModel, ΔŨ::NTuple{N, T}) where {N, T}
317+
con_nonlinprog(mpc::NonLinMPC, ::LinModel, ΔŨ::Vector{Real})
316318
317319
Nonlinear constraints for [`NonLinMPC`](@ref) when `model` is a [`LinModel`](@ref).
318320
"""
319-
function con_nonlinprog(mpc::NonLinMPC, model::LinModel, _, ΔŨ::NTuple{N, T}) where {N, T}
321+
function con_nonlinprog(mpc::NonLinMPC, model::LinModel, _, ΔŨ::Vector{T}) where {T<:Real}
320322
return zeros(T, 2*model.ny*mpc.Hp)
321323
end
322324
"""
323-
con_nonlinprog(mpc::NonLinMPC, model::NonLinModel, ΔŨ::NTuple{N, T}) where {N, T}
325+
con_nonlinprog(mpc::NonLinMPC, model::NonLinModel, ΔŨ::Vector{Real})
324326
325327
Nonlinear constrains for [`NonLinMPC`](@ref) when `model` is not a [`LinModel`](ref).
326328
"""
327-
function con_nonlinprog(mpc::NonLinMPC, ::SimModel, Ŷ, ΔŨ::NTuple{N, T}) where {N, T}
329+
function con_nonlinprog(mpc::NonLinMPC, ::SimModel, Ŷ, ΔŨ::Vector{T}) where {T<:Real}
328330
if !isinf(mpc.C) # constraint softening activated :
329331
ϵ = ΔŨ[end]
330332
C_Ŷmin = (mpc.con.Ŷmin - Ŷ) - ϵ*mpc.con.c_Ŷmin

0 commit comments

Comments
 (0)