Skip to content

Commit 3cdc70b

Browse files
committed
reduce allocation NonLinMPC
1 parent e44b2e1 commit 3cdc70b

File tree

1 file changed

+61
-58
lines changed

1 file changed

+61
-58
lines changed

src/controller/nonlinmpc.jl

Lines changed: 61 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ end
281281
282282
Init the nonlinear optimization for [`NonLinMPC`](@ref) controllers.
283283
"""
284-
function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
284+
function init_optimization!(mpc::NonLinMPC, optim)
285285
# --- variables and linear constraints ---
286286
C, con = mpc.C, mpc.con
287287
nΔŨ = length(mpc.ΔŨ)
@@ -300,65 +300,12 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
300300
JuMP.set_attribute(optim, "nlp_scaling_max_gradient", 10.0/C)
301301
end
302302
end
303-
model = mpc.estim.model
304-
nu, ny, nx̂, Hp, ng = model.nu, model.ny, mpc.estim.nx̂, mpc.Hp, length(con.i_g)
305-
# inspired from https://jump.dev/JuMP.jl/stable/tutorials/nonlinear/tips_and_tricks/#User-defined-operators-with-vector-outputs
306-
Jfunc, gfunc = let mpc=mpc, model=model, ng=ng, nΔŨ=nΔŨ, nŶ=Hp*ny, nx̂=nx̂, nu=nu, nU=Hp*nu
307-
Nc = nΔŨ + 3
308-
last_ΔŨtup_float, last_ΔŨtup_dual = nothing, nothing
309-
ΔŨ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nΔŨ), Nc)
310-
Ŷ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Nc)
311-
U_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Nc)
312-
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), Nc)
313-
x̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
314-
x̂next_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
315-
u_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
316-
û_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
317-
Ȳ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Nc)
318-
Ū_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Nc)
319-
function Jfunc(ΔŨtup::T...)::T where {T <: Real}
320-
ΔŨ1 = ΔŨtup[begin]
321-
if T == JNT
322-
last_ΔŨtup_float = ΔŨtup
323-
else
324-
last_ΔŨtup_dual = ΔŨtup
325-
end
326-
ΔŨ, Ŷ = get_tmp(ΔŨ_cache, ΔŨ1), get_tmp(Ŷ_cache, ΔŨ1)
327-
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
328-
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
329-
ΔŨ .= ΔŨtup
330-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
331-
g = get_tmp(g_cache, ΔŨ1)
332-
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
333-
U, Ȳ, Ū = get_tmp(U_cache, ΔŨ1), get_tmp(Ȳ_cache, ΔŨ1), get_tmp(Ū_cache, ΔŨ1)
334-
return obj_nonlinprog!(U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)::T
335-
end
336-
function gfunc_i(i, ΔŨtup::NTuple{N, T})::T where {N, T <:Real}
337-
ΔŨ1 = ΔŨtup[begin]
338-
g = get_tmp(g_cache, ΔŨ1)
339-
if T == JNT
340-
isnewvalue = (ΔŨtup !== last_ΔŨtup_float)
341-
isnewvalue && (last_ΔŨtup_float = ΔŨtup)
342-
else
343-
isnewvalue = (ΔŨtup !== last_ΔŨtup_dual)
344-
isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
345-
end
346-
if isnewvalue
347-
ΔŨ, Ŷ = get_tmp(ΔŨ_cache, ΔŨ1), get_tmp(Ŷ_cache, ΔŨ1)
348-
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
349-
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
350-
ΔŨ .= ΔŨtup
351-
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
352-
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
353-
end
354-
return g[i]::T
355-
end
356-
gfunc = [(ΔŨ...) -> gfunc_i(i, ΔŨ) for i in 1:ng]
357-
(Jfunc, gfunc)
358-
end
303+
Jfunc, gfunc = get_optim_functions(mpc, mpc.optim)
359304
register(optim, :Jfunc, nΔŨ, Jfunc, autodiff=true)
360305
@NLobjective(optim, Min, Jfunc(ΔŨvar...))
361-
if ng 0
306+
model = mpc.estim.model
307+
ny, nx̂, Hp = model.ny, mpc.estim.nx̂, mpc.Hp
308+
if length(con.i_g) 0
362309
for i in eachindex(con.Ymin)
363310
sym = Symbol("g_Ymin_$i")
364311
register(optim, sym, nΔŨ, gfunc[i], autodiff=true)
@@ -382,6 +329,62 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
382329
return nothing
383330
end
384331

332+
"""
333+
get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel) -> Jfunc, gfunc
334+
335+
Get the objective `Jfunc` and constraints `gfunc` functions for [`NonLinMPC`](@ref).
336+
337+
Inspired from: [User-defined operators with vector outputs](https://jump.dev/JuMP.jl/stable/tutorials/nonlinear/tips_and_tricks/#User-defined-operators-with-vector-outputs)
338+
"""
339+
function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
340+
model = mpc.estim.model
341+
nu, ny, nx̂, Hp = model.nu, model.ny, mpc.estim.nx̂, mpc.Hp
342+
ng, nΔŨ, nU, nŶ = length(mpc.con.i_g), length(mpc.ΔŨ), Hp*nu, Hp*ny
343+
Nc = nΔŨ + 3
344+
ΔŨ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nΔŨ), Nc)
345+
Ŷ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Nc)
346+
U_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Nc)
347+
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), Nc)
348+
x̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
349+
x̂next_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
350+
u_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
351+
û_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
352+
Ȳ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Nc)
353+
Ū_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Nc)
354+
function Jfunc(ΔŨtup::T...) where T<:Real
355+
ΔŨ1 = ΔŨtup[begin]
356+
ΔŨ, g = get_tmp(ΔŨ_cache, ΔŨ1), get_tmp(g_cache, ΔŨ1)
357+
for i in eachindex(ΔŨtup)
358+
ΔŨ[i] = ΔŨtup[i] # ΔŨ .= ΔŨtup seems to produce a type instability
359+
end
360+
= get_tmp(Ŷ_cache, ΔŨ1)
361+
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
362+
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
363+
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
364+
g = get_tmp(g_cache, ΔŨ1)
365+
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
366+
U, Ȳ, Ū = get_tmp(U_cache, ΔŨ1), get_tmp(Ȳ_cache, ΔŨ1), get_tmp(Ū_cache, ΔŨ1)
367+
return obj_nonlinprog!(U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)::T
368+
end
369+
function gfunc_i(i, ΔŨtup::NTuple{N, T}) where {N, T<:Real}
370+
ΔŨ1 = ΔŨtup[begin]
371+
ΔŨ, g = get_tmp(ΔŨ_cache, ΔŨ1), get_tmp(g_cache, ΔŨ1)
372+
if any(new old for (new, old) in zip(ΔŨtup, ΔŨ)) # new ΔŨtup, update predictions:
373+
for i in eachindex(ΔŨtup)
374+
ΔŨ[i] = ΔŨtup[i] # ΔŨ .= ΔŨtup seems to produce a type instability
375+
end
376+
= get_tmp(Ŷ_cache, ΔŨ1)
377+
x̂, x̂next = get_tmp(x̂_cache, ΔŨ1), get_tmp(x̂next_cache, ΔŨ1)
378+
u, û = get_tmp(u_cache, ΔŨ1), get_tmp(û_cache, ΔŨ1)
379+
Ŷ, x̂end = predict!(Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
380+
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
381+
end
382+
return g[i]::T
383+
end
384+
gfunc = [(ΔŨ...) -> gfunc_i(i, ΔŨ) for i in 1:ng]
385+
return Jfunc, gfunc
386+
end
387+
385388
"Set the nonlinear constraints on the output predictions `Ŷ` and terminal states `x̂end`."
386389
function setnonlincon!(mpc::NonLinMPC, ::NonLinModel)
387390
optim = mpc.optim

0 commit comments

Comments
 (0)