Skip to content

Commit 265c6ca

Browse files
committed
include memoize code in NonLinMPC constructor
1 parent 3d883eb commit 265c6ca

File tree

2 files changed

+39
-44
lines changed

2 files changed

+39
-44
lines changed

example/juMPC.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# spell-checker: disable
22

3-
#using Pkg
4-
#using Revise
5-
#Pkg.activate(".")
3+
using Pkg
4+
using Revise
5+
Pkg.activate(".")
66

77

88
using ModelPredictiveControl
9-
#using Preferences
10-
#set_preferences!(ModelPredictiveControl, "precompile_workload" => false; force=true)
9+
using Preferences
10+
set_preferences!(ModelPredictiveControl, "precompile_workload" => false; force=true)
1111

1212

1313
#using JuMP, DAQP

src/controller/nonlinmpc.jl

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,33 @@ struct NonLinMPC{S<:StateEstimator, JEFunc<:Function} <: PredictiveController
7777
end
7878
register(optim, :J, nvar, J, autodiff=true)
7979
@NLobjective(optim, Min, J(ΔŨ...))
80-
nonlinconstraint = let mpc=mpc, model=model # capture mpc and model variables
81-
(ΔŨ...) -> con_nonlinprog(mpc, model, ΔŨ)
80+
ncon = length(mpc.con.Ŷmin) + length(mpc.con.Ŷmax)
81+
C = let mpc=mpc, model=model, ncon=ncon # capture the 3 variables
82+
last_ΔŨ, last_C = nothing, nothing
83+
function con_nonlinprog_i(i, ΔŨ::NTuple{N, Float64}) where {N}
84+
if ΔŨ !== last_ΔŨ
85+
last_C, last_ΔŨ = con_nonlinprog(mpc, model, ΔŨ), ΔŨ
86+
end
87+
return last_C[i]
88+
end
89+
last_dΔŨ, last_dCdΔŨ = nothing, nothing
90+
function con_nonlinprog_i(i, dΔŨ::NTuple{N, T}) where {N, T<:Real}
91+
if dΔŨ !== last_dΔŨ
92+
last_dCdΔŨ, last_dΔŨ = con_nonlinprog(mpc, model, dΔŨ), dΔŨ
93+
end
94+
return last_dCdΔŨ[i]
95+
end
96+
[(ΔŨ...) -> con_nonlinprog_i(i, ΔŨ) for i in 1:ncon]
8297
end
83-
nonlincon_memoized = memoize(nonlinconstraint, 2*ny*Hp)
8498
n = 0
8599
for i in eachindex(con.Ŷmin)
86100
sym = Symbol("C_Ŷmin_$i")
87-
register(optim, sym, nvar, nonlincon_memoized[n + i], autodiff=true)
101+
register(optim, sym, nvar, C[n + i], autodiff=true)
88102
end
89103
n = lastindex(con.Ŷmin)
90104
for i in eachindex(con.Ŷmax)
91105
sym = Symbol("C_Ŷmax_$i")
92-
register(optim, sym, nvar, nonlincon_memoized[n + i], autodiff=true)
106+
register(optim, sym, nvar, C[n + i], autodiff=true)
93107
end
94108
set_silent(optim)
95109
return mpc
@@ -227,7 +241,7 @@ init_objective!(mpc::NonLinMPC, _ ) = nothing
227241
"""
228242
obj_nonlinprog(mpc::NonLinMPC, model::LinModel, ΔŨ::NTuple{N, T}) where {N, T}
229243
230-
TBW
244+
Objective function for [`NonLinMPC`] when `model` is a [`LinModel`](@ref).
231245
"""
232246
function obj_nonlinprog(mpc::NonLinMPC, model::LinModel, ΔŨ::NTuple{N, T}) where {N, T}
233247
ΔŨ = collect(ΔŨ) # convert NTuple to Vector
@@ -242,7 +256,7 @@ end
242256
"""
243257
obj_nonlinprog(mpc::NonLinMPC, model::SimModel, ΔŨ::NTuple{N, T}) where {N, T}
244258
245-
TBW
259+
Objective function for [`NonLinMPC`] when `model` is not a [`LinModel`](@ref).
246260
"""
247261
function obj_nonlinprog(mpc::NonLinMPC, model::SimModel, ΔŨ::NTuple{N, T}) where {N, T}
248262
ΔŨ = collect(ΔŨ) # convert NTuple to Vector
@@ -271,13 +285,18 @@ function obj_nonlinprog(mpc::NonLinMPC, model::SimModel, ΔŨ::NTuple{N, T}) wh
271285
end
272286

273287

288+
"""
289+
con_nonlinprog(mpc::NonLinMPC, ::LinModel, ΔŨ::NTuple{N, T}) where {N, T}
290+
291+
Nonlinear constraints for [`NonLinMPC`](@ref) when `model` is a [`LinModel`](@ref).
292+
"""
274293
function con_nonlinprog(mpc::NonLinMPC, ::LinModel, ΔŨ::NTuple{N, T}) where {N, T}
275294
return zeros(T, 2*mpc.ny*mpc.Hp)
276295
end
277296
"""
278297
con_nonlinprog(mpc::NonLinMPC, model::NonLinModel, ΔŨ::NTuple{N, T}) where {N, T}
279298
280-
TBW
299+
Nonlinear constrains for [`NonLinMPC`](@ref) when `model` is not a [`LinModel`](ref).
281300
"""
282301
function con_nonlinprog(mpc::NonLinMPC, model::SimModel, ΔŨ::NTuple{N, T}) where {N, T}
283302
ΔŨ = collect(ΔŨ) # convert NTuple to Vector
@@ -298,6 +317,12 @@ function con_nonlinprog(mpc::NonLinMPC, model::SimModel, ΔŨ::NTuple{N, T}) wh
298317
return C
299318
end
300319

320+
321+
"""
322+
evalŶ(mpc::NonLinMPC, model::SimModel, x̂d, d0, D̂0, U0::Vector{T}) where {T}
323+
324+
Evaluate the outputs predictions ``\\mathbf{Ŷ}`` when `model` is not a [`LinModel`](@ref).
325+
"""
301326
function evalŶ(mpc::NonLinMPC, model::SimModel, x̂d, d0, D̂0, U0::Vector{T}) where {T}
302327
Ŷd0 = Vector{T}(undef, model.ny*mpc.Hp)
303328
x̂d::Vector{T} = copy(x̂d)
@@ -307,35 +332,5 @@ function evalŶ(mpc::NonLinMPC, model::SimModel, x̂d, d0, D̂0, U0::Vector{T})
307332
d0 = D̂0[(1 + model.nd*(j-1)):(model.nd*j)]
308333
Ŷd0[(1 + model.ny*(j-1)):(model.ny*j)] = h(model, x̂d, d0)
309334
end
310-
return Ŷd0 + mpc.F # mpc.F = Yop + Ŷs
311-
end
312-
313-
"""
314-
memoize(f::Function, n_outputs::Int)
315-
316-
Memoize `f` to reduce the computational cost of [`NonLinMPC`](@ref) controllers.
317-
318-
Take a function `f` and return a vector of length `n_outputs`, where element `i` is a
319-
function that returns the equivalent of `f(x...)[i]`. To avoid duplication of work, cache
320-
the most-recent evaluations of `f`. Because `f_i` is auto-differentiated with ForwardDiff,
321-
our cache needs to work when `x` is a `Float64` and a `ForwardDiff.Dual`.
322-
"""
323-
function memoize(f::Function, n_outputs::Int)
324-
last_x , last_f = nothing, nothing
325-
function f_i(i, x::NTuple{N, Float64}) where {N}
326-
if x !== last_x
327-
last_f = f(x...)
328-
last_x = x
329-
end
330-
return last_f[i]
331-
end
332-
last_dx, last_dfdx = nothing, nothing
333-
function f_i(i, dx::NTuple{N, T}) where {N, T<:Real}
334-
if dx !== last_dx
335-
last_dfdx = f(dx...)
336-
last_dx = dx
337-
end
338-
return last_dfdx[i]
339-
end
340-
return [(x...) -> f_i(i, x) for i in 1:n_outputs]
335+
return Ŷd0 + mpc.F # F = Yop + Ŷs
341336
end

0 commit comments

Comments
 (0)