Skip to content

Commit 674e4ab

Browse files
committed
reduce allocation NonLinMPC
1 parent d9f024f commit 674e4ab

File tree

2 files changed

+40
-31
lines changed

2 files changed

+40
-31
lines changed

src/controller/execute.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,13 @@ function getinfo(mpc::PredictiveController{NT}) where NT<:Real
106106
info = Dict{Symbol, Union{JuMP._SolutionSummary, Vector{NT}, NT}}()
107107
Ŷ, x̂, u0 = similar(mpc.Ŷop), similar(mpc.estim.x̂), similar(mpc.estim.lastu0)
108108
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, mpc.estim.model, mpc.ΔŨ)
109+
U = mpc.*mpc.ΔŨ + mpc.T*(mpc.estim.lastu0 + mpc.estim.model.uop)
110+
Ȳ, Ū = similar(Ŷ), similar(U)
111+
J = obj_nonlinprog!(Ȳ, Ū, mpc, mpc.estim.model, Ŷ, mpc.ΔŨ)
109112
info[:ΔU] = mpc.ΔŨ[1:mpc.Hc*mpc.estim.model.nu]
110113
info[] = isinf(mpc.C) ? NaN : mpc.ΔŨ[end]
111-
info[:J] = obj_nonlinprog(mpc, mpc.estim.model, Ŷ, mpc.ΔŨ)
112-
info[:U] = mpc.*mpc.ΔŨ + mpc.T*(mpc.estim.lastu0 + mpc.estim.model.uop)
114+
info[:J] = J
115+
info[:U] = U
113116
info[:u] = info[:U][1:mpc.estim.model.nu]
114117
info[:d] = mpc.d0 + mpc.estim.model.dop
115118
info[:D̂] = mpc.D̂0 + mpc.Dop
@@ -309,16 +312,17 @@ function predict!(
309312
end
310313

311314
"""
312-
obj_nonlinprog(mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ)
315+
obj_nonlinprog!(_ , _ , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ)
313316
314317
Nonlinear programming objective function when `model` is a [`LinModel`](@ref).
315318
316319
The function is called by the nonlinear optimizer of [`NonLinMPC`](@ref) controllers. It can
317320
also be called on any [`PredictiveController`](@ref)s to evaluate the objective function `J`
318-
at specific input increments `ΔŨ` and predictions `Ŷ` values.
321+
at specific input increments `ΔŨ` and predictions `Ŷ` values. This method does not mutate
322+
its argument.
319323
"""
320-
function obj_nonlinprog(
321-
mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ::Vector{NT}
324+
function obj_nonlinprog!(
325+
_ , _ , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ::Vector{NT}
322326
) where {NT<:Real}
323327
J = obj_quadprog(ΔŨ, mpc.H̃, mpc.q̃) + mpc.p[]
324328
if !iszero(mpc.E)
@@ -331,17 +335,18 @@ function obj_nonlinprog(
331335
end
332336

333337
"""
334-
obj_nonlinprog(mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ)
338+
obj_nonlinprog!(Ȳ, Ū. mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ)
335339
336340
Nonlinear programming objective function when `model` is not a [`LinModel`](@ref). The
337-
function `dot(x, A, x)` is a performant way of calculating `x'*A*x`.
341+
function `dot(x, A, x)` is a performant way of calculating `x'*A*x`. This method mutates
342+
`Ȳ` and `Ū` vector arguments (output and input setpoint tracking error, respectively).
338343
"""
339-
function obj_nonlinprog(
340-
mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ::Vector{NT}
344+
function obj_nonlinprog!(
345+
Ȳ, Ū, mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ::Vector{NT}
341346
) where {NT<:Real}
342347
# --- output setpoint tracking term ---
343-
êy = mpc.R̂y -
344-
JR̂y = dot(êy, mpc.M_Hp, êy)
348+
Ȳ .= mpc.R̂y .-
349+
JR̂y = dot(, mpc.M_Hp, )
345350
# --- move suppression and slack variable term ---
346351
JΔŨ = dot(ΔŨ, mpc.Ñ_Hc, ΔŨ)
347352
# --- input over prediction horizon ---
@@ -350,8 +355,8 @@ function obj_nonlinprog(
350355
end
351356
# --- input setpoint tracking term ---
352357
if !mpc.noR̂u
353-
êu = mpc.R̂u - U
354-
JR̂u = dot(êu, mpc.L_Hp, êu)
358+
Ū .= mpc.R̂u .- U
359+
JR̂u = dot(, mpc.L_Hp, )
355360
else
356361
JR̂u = 0.0
357362
end

src/controller/nonlinmpc.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -305,38 +305,42 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
305305
g_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), nΔŨ + 3)
306306
x̂_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), nΔŨ + 3)
307307
u0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), nΔŨ + 3)
308+
Ȳ_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), nΔŨ + 3)
309+
Ū_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), nΔŨ + 3)
308310
function Jfunc(ΔŨtup::JNT...)
309-
= get_tmp(Ŷ_cache, ΔŨtup[1])
311+
ΔŨtud1 = ΔŨtup[begin]
312+
= get_tmp(Ŷ_cache, ΔŨtud1)
310313
ΔŨ = collect(ΔŨtup)
311314
if ΔŨtup !== last_ΔŨtup_float
312-
= get_tmp(x̂_cache, ΔŨtup[1])
313-
u0 = get_tmp(u0_cache, ΔŨtup[1])
315+
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
314316
g = get_tmp(g_cache, ΔŨtup[1])
315317
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
316318
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
317319
last_ΔŨtup_float = ΔŨtup
318320
end
319-
return obj_nonlinprog(mpc, model, Ŷ, ΔŨ)
321+
Ȳ, Ū = get_tmp(Ȳ_cache, ΔŨtud1), get_tmp(Ū_cache, ΔŨtud1)
322+
return obj_nonlinprog!(Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
320323
end
321324
function Jfunc(ΔŨtup::ForwardDiff.Dual...)
322-
= get_tmp(Ŷ_cache, ΔŨtup[1])
325+
ΔŨtud1 = ΔŨtup[begin]
326+
= get_tmp(Ŷ_cache, ΔŨtud1)
323327
ΔŨ = collect(ΔŨtup)
324328
if ΔŨtup !== last_ΔŨtup_dual
325-
= get_tmp(x̂_cache, ΔŨtup[1])
326-
u0 = get_tmp(u0_cache, ΔŨtup[1])
327-
g = get_tmp(g_cache, ΔŨtup[1])
329+
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
330+
g = get_tmp(g_cache, ΔŨtud1)
328331
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
329332
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
330333
last_ΔŨtup_dual = ΔŨtup
331334
end
332-
return obj_nonlinprog(mpc, model, Ŷ, ΔŨ)
335+
Ȳ, Ū = get_tmp(Ȳ_cache, ΔŨtud1), get_tmp(Ū_cache, ΔŨtud1)
336+
return obj_nonlinprog!(Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
333337
end
334338
function gfunc_i(i, ΔŨtup::NTuple{N, JNT}) where N
335-
g = get_tmp(g_cache, ΔŨtup[1])
339+
ΔŨtud1 = ΔŨtup[begin]
340+
g = get_tmp(g_cache, ΔŨtud1)
336341
if ΔŨtup !== last_ΔŨtup_float
337-
= get_tmp(x̂_cache, ΔŨtup[1])
338-
u0 = get_tmp(u0_cache, ΔŨtup[1])
339-
= get_tmp(Ŷ_cache, ΔŨtup[1])
342+
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
343+
= get_tmp(Ŷ_cache, ΔŨtud1)
340344
ΔŨ = collect(ΔŨtup)
341345
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
342346
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
@@ -345,11 +349,11 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
345349
return g[i]
346350
end
347351
function gfunc_i(i, ΔŨtup::NTuple{N, ForwardDiff.Dual}) where N
348-
g = get_tmp(g_cache, ΔŨtup[1])
352+
ΔŨtud1 = ΔŨtup[begin]
353+
g = get_tmp(g_cache, ΔŨtud1)
349354
if ΔŨtup !== last_ΔŨtup_dual
350-
= get_tmp(x̂_cache, ΔŨtup[1])
351-
u0 = get_tmp(u0_cache, ΔŨtup[1])
352-
= get_tmp(Ŷ_cache, ΔŨtup[1])
355+
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
356+
= get_tmp(Ŷ_cache, ΔŨtud1)
353357
ΔŨ = collect(ΔŨtup)
354358
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
355359
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)

0 commit comments

Comments
 (0)