Skip to content

Commit 4462cd7

Browse files
committed
reduce allocation for NonLinMPC
1 parent 725ec21 commit 4462cd7

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

src/controller/execute.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ julia> round.(getinfo(mpc)[:Ŷ], digits=3)
104104
"""
105105
function getinfo(mpc::PredictiveController{NT}) where NT<:Real
106106
info = Dict{Symbol, Union{JuMP._SolutionSummary, Vector{NT}, NT}}()
107-
Ŷ, x̂ = similar(mpc.Ŷop), similar(mpc.estim.x̂)
108-
Ŷ, x̂end = predict!(Ŷ, x̂, mpc, mpc.estim.model, mpc.ΔŨ)
107+
Ŷ, x̂, u0 = similar(mpc.Ŷop), similar(mpc.estim.), similar(mpc.estim.lastu0)
108+
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, mpc.estim.model, mpc.ΔŨ)
109109
info[:ΔU] = mpc.ΔŨ[1:mpc.Hc*mpc.estim.model.nu]
110110
info[] = isinf(mpc.C) ? NaN : mpc.ΔŨ[end]
111111
info[:J] = obj_nonlinprog(mpc, mpc.estim.model, Ŷ, mpc.ΔŨ)
@@ -263,14 +263,16 @@ function linconstraint!(mpc::PredictiveController, model::SimModel)
263263
end
264264

265265
@doc raw"""
266-
predict!(Ŷ, x̂, mpc::PredictiveController, model::LinModel, ΔŨ) -> Ŷ, x̂end
266+
predict!(Ŷ, x̂, _ , mpc::PredictiveController, model::LinModel, ΔŨ) -> Ŷ, x̂end
267267
268268
Compute the predictions `Ŷ` and terminal states `x̂end` if model is a [`LinModel`](@ref).
269269
270270
The method mutates `Ŷ` and `x̂` vector arguments. The `x̂end` vector is used for
271271
the terminal constraints applied on ``\mathbf{x̂}_{k-1}(k+H_p)``.
272272
"""
273-
function predict!(Ŷ, x̂, mpc::PredictiveController, ::LinModel, ΔŨ::Vector{NT}) where {NT<:Real}
273+
function predict!(
274+
Ŷ, x̂, _ , mpc::PredictiveController, ::LinModel, ΔŨ::Vector{NT}
275+
) where {NT<:Real}
274276
# in-place operations to reduce allocations :
275277
Ŷ .= mul!(Ŷ, mpc.Ẽ, ΔŨ) .+ mpc.F
276278
x̂ .= mul!(x̂, mpc.con.ẽx̂, ΔŨ) .+ mpc.con.fx̂
@@ -279,14 +281,19 @@ function predict!(Ŷ, x̂, mpc::PredictiveController, ::LinModel, ΔŨ::Vector
279281
end
280282

281283
@doc raw"""
282-
predict!(Ŷ, x̂, mpc::PredictiveController, model::SimModel, ΔŨ) -> Ŷ, x̂end
284+
predict!(Ŷ, x̂, u0, mpc::PredictiveController, model::SimModel, ΔŨ) -> Ŷ, x̂end
283285
284-
Compute both vectors if `model` is not a [`LinModel`](@ref).
286+
Compute both vectors if `model` is not a [`LinModel`](@ref).
287+
288+
The method mutates `Ŷ`, `x̂` and `u0` arguments. The latter is the manipulated input without
289+
the operating points ``\mathbf{u_0}(k) = \mathbf{u}(k) - \mathbf{u_{op}}(k)``.
285290
"""
286-
function predict!(Ŷ, x̂, mpc::PredictiveController, model::SimModel, ΔŨ::Vector{NT}) where {NT<:Real}
291+
function predict!(
292+
Ŷ, x̂, u0, mpc::PredictiveController, model::SimModel, ΔŨ::Vector{NT}
293+
) where {NT<:Real}
287294
nu, ny, nd, Hp, Hc = model.nu, model.ny, model.nd, mpc.Hp, mpc.Hc
288-
x̂ .= mpc.estim.
289-
u0::Vector{NT} = copy(mpc.estim.lastu0)
295+
.= mpc.estim.
296+
u0 .= mpc.estim.lastu0
290297
d0 = @views mpc.d0[1:end]
291298
for j=1:Hp
292299
if j Hc

src/controller/nonlinmpc.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -297,20 +297,22 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
297297
end
298298
end
299299
model = mpc.estim.model
300-
ny, nx̂, Hp, ng = model.ny, mpc.estim.nx̂, mpc.Hp, length(con.i_g)
300+
nu, ny, nx̂, Hp, ng = model.nu, model.ny, mpc.estim.nx̂, mpc.Hp, length(con.i_g)
301301
# inspired from https://jump.dev/JuMP.jl/stable/tutorials/nonlinear/tips_and_tricks/#User-defined-operators-with-vector-outputs
302-
Jfunc, gfunc = let mpc=mpc, model=model, ng=ng, nΔŨ=nΔŨ , nŶ=Hp*ny, nx̂=nx̂
302+
Jfunc, gfunc = let mpc=mpc, model=model, ng=ng, nΔŨ=nΔŨ , nŶ=Hp*ny, nx̂=nx̂, nu=nu
303303
last_ΔŨtup_float, last_ΔŨtup_dual = nothing, nothing
304-
Ŷ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), nΔŨ + 3)
305-
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), nΔŨ + 3)
306-
x̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), nΔŨ + 3)
304+
Ŷ_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), nΔŨ + 3)
305+
g_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), nΔŨ + 3)
306+
x̂_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), nΔŨ + 3)
307+
u0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), nΔŨ + 3)
307308
function Jfunc(ΔŨtup::JNT...)
308309
= get_tmp(Ŷ_cache, ΔŨtup[1])
309310
ΔŨ = collect(ΔŨtup)
310311
if ΔŨtup !== last_ΔŨtup_float
311-
= get_tmp(x̂_cache, ΔŨtup[1])
312-
g = get_tmp(g_cache, ΔŨtup[1])
313-
Ŷ, x̂end = predict!(Ŷ, x̂, mpc, model, ΔŨ)
312+
= get_tmp(x̂_cache, ΔŨtup[1])
313+
u0 = get_tmp(u0_cache, ΔŨtup[1])
314+
g = get_tmp(g_cache, ΔŨtup[1])
315+
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
314316
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
315317
last_ΔŨtup_float = ΔŨtup
316318
end
@@ -321,8 +323,9 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
321323
ΔŨ = collect(ΔŨtup)
322324
if ΔŨtup !== last_ΔŨtup_dual
323325
= get_tmp(x̂_cache, ΔŨtup[1])
326+
u0 = get_tmp(u0_cache, ΔŨtup[1])
324327
g = get_tmp(g_cache, ΔŨtup[1])
325-
Ŷ, x̂end = predict!(Ŷ, x̂, mpc, model, ΔŨ)
328+
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
326329
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
327330
last_ΔŨtup_dual = ΔŨtup
328331
end
@@ -332,9 +335,10 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
332335
g = get_tmp(g_cache, ΔŨtup[1])
333336
if ΔŨtup !== last_ΔŨtup_float
334337
= get_tmp(x̂_cache, ΔŨtup[1])
338+
u0 = get_tmp(u0_cache, ΔŨtup[1])
335339
= get_tmp(Ŷ_cache, ΔŨtup[1])
336340
ΔŨ = collect(ΔŨtup)
337-
Ŷ, x̂end = predict!(Ŷ, x̂, mpc, model, ΔŨ)
341+
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
338342
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
339343
last_ΔŨtup_float = ΔŨtup
340344
end
@@ -344,9 +348,10 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
344348
g = get_tmp(g_cache, ΔŨtup[1])
345349
if ΔŨtup !== last_ΔŨtup_dual
346350
= get_tmp(x̂_cache, ΔŨtup[1])
351+
u0 = get_tmp(u0_cache, ΔŨtup[1])
347352
= get_tmp(Ŷ_cache, ΔŨtup[1])
348353
ΔŨ = collect(ΔŨtup)
349-
Ŷ, x̂end = predict!(Ŷ, x̂, mpc, model, ΔŨ)
354+
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
350355
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
351356
last_ΔŨtup_dual = ΔŨtup
352357
end

0 commit comments

Comments
 (0)