@@ -304,6 +304,7 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
304304 Jfunc, gfunc = let mpc= mpc, model= model, ng= ng, nΔŨ= nΔŨ, nŶ= Hp* ny, nx̂= nx̂, nu= nu, nU= Hp* nu
305305 Nc = nΔŨ + 3
306306 last_ΔŨtup_float, last_ΔŨtup_dual = nothing , nothing
307+ ΔŨ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nΔŨ), Nc)
307308 Ŷ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
308309 U_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
309310 g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), Nc)
@@ -313,63 +314,46 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
313314 û_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nu), Nc)
314315 Ȳ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
315316 Ū_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
316- function Jfunc (ΔŨtup:: JNT ... )
317+ function Jfunc (ΔŨtup:: T ... ):: T where T <: Real
317318 ΔŨ1 = ΔŨtup[begin ]
318- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
319- ΔŨ = collect (ΔŨtup)
320- if ΔŨtup != = last_ΔŨtup_float
321- x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
322- u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
323- Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
324- g = get_tmp (g_cache, ΔŨ1)
325- g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
326- last_ΔŨtup_float = ΔŨtup
319+ ΔŨ, Ŷ = get_tmp (ΔŨ_cache, ΔŨ1), get_tmp (Ŷ_cache, ΔŨ1)
320+ if T == JNT
321+ isnewvalue = (ΔŨtup != = last_ΔŨtup_float)
322+ isnewvalue && (last_ΔŨtup_float = ΔŨtup)
323+ else
324+ isnewvalue = (ΔŨtup != = last_ΔŨtup_dual)
325+ isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
327326 end
328- U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
329- return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
330- end
331- function Jfunc (ΔŨtup:: ForwardDiff.Dual... )
332- ΔŨ1 = ΔŨtup[begin ]
333- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
334- ΔŨ = collect (ΔŨtup)
335- if ΔŨtup != = last_ΔŨtup_dual
327+ if isnewvalue
336328 x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
337329 u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
330+ ΔŨ .= ΔŨtup
338331 Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
339332 g = get_tmp (g_cache, ΔŨ1)
340333 g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
341- last_ΔŨtup_dual = ΔŨtup
342334 end
343335 U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
344- return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
336+ return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ):: T
345337 end
346- function gfunc_i (i, ΔŨtup:: NTuple{N, JNT} ) where N
338+ function gfunc_i (i, ΔŨtup:: NTuple{N, T} ) :: T where {N, T <: Real }
347339 ΔŨ1 = ΔŨtup[begin ]
348340 g = get_tmp (g_cache, ΔŨ1)
349- if ΔŨtup != = last_ΔŨtup_float
350- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
351- ΔŨ = collect (ΔŨtup)
352- x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
353- u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
354- Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
355- g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
356- last_ΔŨtup_float = ΔŨtup
341+ if T == JNT
342+ isnewvalue = (ΔŨtup != = last_ΔŨtup_float)
343+ isnewvalue && (last_ΔŨtup_float = ΔŨtup)
344+ else
345+ isnewvalue = (ΔŨtup != = last_ΔŨtup_dual)
346+ isnewvalue && (last_ΔŨtup_dual = ΔŨtup)
357347 end
358- return g[i]
359- end
360- function gfunc_i (i, ΔŨtup:: NTuple{N, ForwardDiff.Dual} ) where N
361- ΔŨ1 = ΔŨtup[begin ]
362- g = get_tmp (g_cache, ΔŨ1)
363- if ΔŨtup != = last_ΔŨtup_dual
364- Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
365- ΔŨ = collect (ΔŨtup)
348+ if isnewvalue
349+ ΔŨ, Ŷ = get_tmp (ΔŨ_cache, ΔŨ1), get_tmp (Ŷ_cache, ΔŨ1)
366350 x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
367351 u, û = get_tmp (u_cache, ΔŨ1), get_tmp (û_cache, ΔŨ1)
352+ ΔŨ .= ΔŨtup
368353 Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, û, mpc, model, ΔŨ)
369354 g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
370- last_ΔŨtup_dual = ΔŨtup
371355 end
372- return g[i]
356+ return g[i]:: T
373357 end
374358 gfunc = [(ΔŨ... ) -> gfunc_i (i, ΔŨ) for i in 1 : ng]
375359 (Jfunc, gfunc)
0 commit comments