Skip to content

Commit 21bd589

Browse files
committed
in-place NonLinModel seems to work
1 parent 50c920e commit 21bd589

File tree

7 files changed

+109
-53
lines changed

7 files changed

+109
-53
lines changed

src/controller/execute.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,15 +325,18 @@ function predict!(
325325
nu, ny, nd, Hp, Hc = model.nu, model.ny, model.nd, mpc.Hp, mpc.Hc
326326
u0 = u
327327
x̂ .= mpc.estim.
328+
x̂next = similar(x̂) # TODO: avoid this allocation if possible
328329
u0 .= mpc.estim.lastu0
329330
d0 = @views mpc.d0[1:end]
330331
for j=1:Hp
331332
if j Hc
332333
u0 .+= @views ΔŨ[(1 + nu*(j-1)):(nu*j)]
333334
end
334-
x̂[:] = (mpc.estim, model, x̂, u0, d0)
335-
d0 = @views mpc.D̂0[(1 + nd*(j-1)):(nd*j)]
336-
Ŷ[(1 + ny*(j-1)):(ny*j)] = (mpc.estim, model, x̂, d0)
335+
f̂!(x̂next, mpc.estim, model, x̂, u0, d0)
336+
x̂ .= x̂next
337+
d0 = @views mpc.D̂0[(1 + nd*(j-1)):(nd*j)]
338+
= @views Ŷ[(1 + ny*(j-1)):(ny*j)]
339+
ĥ!(ŷ, mpc.estim, model, x̂, d0)
337340
end
338341
Ŷ .=.+ mpc.Ŷop # Ŷop = Ŷs + Yop, and Ŷs=0 if mpc.estim is not an InternalModel
339342
end =

src/estimator/execute.jl

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ function remove_op!(estim::StateEstimator, u, ym, d)
1515
end
1616

1717
@doc raw"""
18-
(estim::StateEstimator, model::SimModel, x̂, u, d)
18+
!(x̂next, estim::StateEstimator, model::SimModel, x̂, u, d) -> x̂next
1919
20-
State function ``\mathbf{f̂}`` of the augmented model.
20+
Mutating state function ``\mathbf{f̂}`` of the augmented model.
2121
2222
By introducing an augmented state vector ``\mathbf{x̂}`` like in [`augment_model`](@ref), the
2323
function returns the next state of the augmented model, defined as:
@@ -26,28 +26,48 @@ function returns the next state of the augmented model, defined as:
2626
\mathbf{x̂}(k+1) &= \mathbf{f̂}\Big(\mathbf{x̂}(k), \mathbf{u}(k), \mathbf{d}(k)\Big) \\
2727
\mathbf{ŷ}(k) &= \mathbf{ĥ}\Big(\mathbf{x̂}(k), \mathbf{d}(k)\Big)
2828
\end{aligned}
29+
where ``\mathbf{x̂}(k+1)`` is stored in `x̂next` argument.
2930
```
3031
"""
31-
function (estim::StateEstimator, model::SimModel, x̂, u, d)
32+
function !(x̂next, estim::StateEstimator, model::SimModel, x̂, u, d)
3233
# `@views` macro avoid copies with matrix slice operator e.g. [a:b]
3334
@views x̂d, x̂s = x̂[1:model.nx], x̂[model.nx+1:end]
34-
return [f(model, x̂d, u + estim.Cs_u*x̂s, d); estim.As*x̂s]
35+
@views x̂d_next, x̂s_next = x̂next[1:model.nx], x̂next[model.nx+1:end]
36+
T = promote_type(eltype(x̂), eltype(u), eltype(d))
37+
u_us = Vector{T}(undef, model.nu) # TODO: avoid this allocation if possible
38+
u_us .= u .+ mul!(u_us, estim.Cs_u, x̂s)
39+
f!(x̂d_next, model, x̂d, u_us, d)
40+
mul!(x̂s_next, estim.As, x̂s)
41+
return x̂next
3542
end
3643
"Use the augmented model matrices if `model` is a [`LinModel`](@ref)."
37-
(estim::StateEstimator, ::LinModel, x̂, u, d) = estim.*+ estim.B̂u * u + estim.B̂d * d
44+
function f̂!(x̂next, estim::StateEstimator, ::LinModel, x̂, u, d)
45+
x̂next .= 0
46+
mul!(x̂next, estim.Â, x̂, 1, 1)
47+
mul!(x̂next, estim.B̂u, u, 1, 1)
48+
mul!(x̂next, estim.B̂d, d, 1, 1)
49+
return x̂next
50+
end
3851

3952
@doc raw"""
40-
(estim::StateEstimator, model::SimModel, x̂, d)
53+
!(ŷ, estim::StateEstimator, model::SimModel, x̂, d) -> ŷ
4154
42-
Output function ``\mathbf{ĥ}`` of the augmented model, see [`f̂`](@ref) for details.
55+
Mutating output function ``\mathbf{ĥ}`` of the augmented model, see [`f̂!`](@ref).
4356
"""
44-
function (estim::StateEstimator, model::SimModel, x̂, d)
57+
function !(ŷ, estim::StateEstimator, model::SimModel, x̂, d)
4558
# `@views` macro avoid copies with matrix slice operator e.g. [a:b]
4659
@views x̂d, x̂s = x̂[1:model.nx], x̂[model.nx+1:end]
47-
return h(model, x̂d, d) + estim.Cs_y*x̂s
60+
h!(ŷ, model, x̂d, d)
61+
mul!(ŷ, estim.Cs_y, x̂s, 1, 1)
62+
return
4863
end
4964
"Use the augmented model matrices if `model` is a [`LinModel`](@ref)."
50-
(estim::StateEstimator, ::LinModel, x̂, d) = estim.*+ estim.D̂d * d
65+
function ĥ!(ŷ, estim::StateEstimator, ::LinModel, x̂, d)
66+
ŷ .= 0
67+
mul!(ŷ, estim.Ĉ, x̂, 1, 1)
68+
mul!(ŷ, estim.D̂d, d, 1, 1)
69+
return
70+
end
5171

5272

5373
@doc raw"""
@@ -147,8 +167,12 @@ julia> ŷ = evaloutput(kf)
147167
20.0
148168
```
149169
"""
150-
function evaloutput(estim::StateEstimator, d=empty(estim.x̂))
151-
return (estim, estim.model, estim.x̂, d - estim.model.dop) + estim.model.yop
170+
function evaloutput(estim::StateEstimator{NT}, d=empty(estim.x̂)) where NT <: Real
171+
validate_args(estim.model, d)
172+
= Vector{NT}(undef, estim.model.ny)
173+
ĥ!(ŷ, estim, estim.model, estim.x̂, d - estim.model.dop)
174+
.+= estim.model.yop
175+
return
152176
end
153177

154178
"Functor allowing callable `StateEstimator` object as an alias for `evaloutput`."
@@ -186,7 +210,7 @@ updatestate!(::StateEstimator, _ ) = throw(ArgumentError("missing measured outpu
186210
Check `u`, `ym` and `d` sizes against `estim` dimensions.
187211
"""
188212
function validate_args(estim::StateEstimator, u, ym, d)
189-
validate_args(estim.model, u, d)
213+
validate_args(estim.model, d, u)
190214
nym = estim.nym
191215
size(ym) (nym,) && throw(DimensionMismatch("ym size $(size(ym)) ≠ meas. output size ($nym,)"))
192216
end

src/estimator/internal_model.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,20 @@ function matrices_internalmodel(model::SimModel{NT}) where NT<:Real
144144
end
145145

146146
@doc raw"""
147-
(::InternalModel, model::NonLinModel, x̂, u, d)
147+
!(x̂next, ::InternalModel, model::NonLinModel, x̂, u, d)
148148
149149
State function ``\mathbf{f̂}`` of [`InternalModel`](@ref) for [`NonLinModel`](@ref).
150150
151-
It calls `model.f(x̂, u ,d)` since this estimator does not augment the states.
151+
It calls `model.f!(x̂next, x̂, u ,d)` since this estimator does not augment the states.
152152
"""
153-
(::InternalModel, model::NonLinModel, x̂, u, d) = model.f(x̂, u, d)
153+
!(x̂next, ::InternalModel, model::NonLinModel, x̂, u, d) = model.f!(x̂next, x̂, u, d)
154154

155155
@doc raw"""
156-
(::InternalModel, model::NonLinModel, x̂, d)
156+
!(ŷ, ::InternalModel, model::NonLinModel, x̂, d)
157157
158-
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls `model.h`.
158+
Output function ``\mathbf{ĥ}`` of [`InternalModel`](@ref), it calls `model.h!`.
159159
"""
160-
(::InternalModel, model::NonLinModel, x̂, d) = model.h(x̂, d)
160+
!(x̂next, ::InternalModel, model::NonLinModel, x̂, d) = model.h!(x̂next, x̂, d)
161161

162162

163163
@doc raw"""
@@ -218,12 +218,17 @@ function update_estimate!(
218218
model = estim.model
219219
x̂d, x̂s = estim.x̂d, estim.x̂s
220220
# -------------- deterministic model ---------------------
221-
ŷd = h(model, x̂d, d)
222-
x̂d[:] = f(model, x̂d, u, d) # this also updates estim.x̂ (they are the same object)
221+
ŷd, x̂dnext = Vector{NT}(undef, model.ny), Vector{NT}(undef, model.nx)
222+
h!(ŷd, model, x̂d, d)
223+
f!(x̂dnext, model, x̂d, u, d)
224+
x̂d .= x̂dnext # this also updates estim.x̂ (they are the same object)
223225
# --------------- stochastic model -----------------------
226+
x̂snext = Vector{NT}(undef, estim.nxs)
224227
ŷs = zeros(NT, model.ny)
225228
ŷs[estim.i_ym] = ym - ŷd[estim.i_ym] # ŷs=0 for unmeasured outputs
226-
x̂s[:] = estim.Âs*x̂s + estim.B̂s*ŷs
229+
mul!(x̂snext, estim.Âs, x̂s)
230+
mul!(x̂snext, estim.B̂s, ŷs, 1, 1)
231+
x̂s .= x̂snext
227232
return nothing
228233
end
229234

@@ -247,11 +252,12 @@ This estimator does not augment the state vector, thus ``\mathbf{x̂ = x̂_d}``.
247252
"""
248253
function init_estimate!(estim::InternalModel, model::LinModel{NT}, u, ym, d) where NT<:Real
249254
x̂d, x̂s = estim.x̂d, estim.x̂s
250-
x̂d[:] = (I - model.A)\(model.Bu*u + model.Bd*d)
251-
ŷd = h(model, x̂d, d)
255+
x̂d .= (I - model.A)\(model.Bu*u + model.Bd*d)
256+
ŷd = Vector{NT}(undef, model.ny)
257+
h!(ŷd, model, x̂d, d)
252258
ŷs = zeros(NT, model.ny)
253259
ŷs[estim.i_ym] = ym - ŷd[estim.i_ym] # ŷs=0 for unmeasured outputs
254-
x̂s[:] = (I-estim.Âs)\estim.B̂s*ŷs
260+
x̂s .= (I-estim.Âs)\estim.B̂s*ŷs
255261
return nothing
256262
end
257263

src/estimator/kalman.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ function update_estimate!(estim::UnscentedKalmanFilter{NT}, u, ym, d) where NT<:
569569
# --- initialize matrices ---
570570
= Matrix{NT}(undef, nx̂, nσ)
571571
ŷm = Vector{NT}(undef, nym)
572+
= Vector{NT}(undef, estim.model.ny)
572573
Ŷm = Matrix{NT}(undef, nym, nσ)
573574
sqrt_P̂ = LowerTriangular{NT, Matrix{NT}}(Matrix{NT}(undef, nx̂, nx̂))
574575
# --- correction step ---
@@ -578,7 +579,8 @@ function update_estimate!(estim::UnscentedKalmanFilter{NT}, u, ym, d) where NT<:
578579
X̂[:, 2:nx̂+1] .+= γ_sqrt_P̂
579580
X̂[:, nx̂+2:end] .-= γ_sqrt_P̂
580581
for j in axes(Ŷm, 2)
581-
Ŷm[:, j] = @views (estim, estim.model, X̂[:, j], d)[estim.i_ym]
582+
@views ĥ!(ŷ, estim, estim.model, X̂[:, j], d)
583+
@views Ŷm[:, j] .= ŷ[estim.i_ym]
582584
end
583585
mul!(ŷm, Ŷm, m̂)
584586
X̄, Ȳm = X̂, Ŷm
@@ -600,7 +602,7 @@ function update_estimate!(estim::UnscentedKalmanFilter{NT}, u, ym, d) where NT<:
600602
X̂_cor[:, nx̂+2:end] .-= γ_sqrt_P̂_cor
601603
X̂_next = X̂_cor
602604
for j in axes(X̂_next, 2)
603-
X̂_next[:, j] = @views (estim, estim.model, X̂_cor[:, j], u, d)
605+
@views f̂!(X̂_next[:, j], estim, estim.model, X̂_cor[:, j], u, d)
604606
end
605607
x̂_next = mul!(x̂, X̂_next, m̂)
606608
X̄_next = X̂_next
@@ -757,9 +759,13 @@ automatically computes the Jacobians:
757759
```
758760
The matrix ``\mathbf{Ĥ^m}`` is the rows of ``\mathbf{Ĥ}`` that are measured outputs.
759761
"""
760-
function update_estimate!(estim::ExtendedKalmanFilter, u, ym, d=empty(estim.x̂))
761-
= ForwardDiff.jacobian(x̂ -> (estim, estim.model, x̂, u, d), estim.x̂)
762-
= ForwardDiff.jacobian(x̂ -> (estim, estim.model, x̂, d), estim.x̂)
762+
function update_estimate!(
763+
estim::ExtendedKalmanFilter{NT}, u, ym, d=empty(estim.x̂)
764+
) where NT<:Real
765+
model = estim.model
766+
x̂next, ŷ = Vector{NT}(undef, estim.nx̂), Vector{NT}(undef, model.ny)
767+
= ForwardDiff.jacobian((x̂next, x̂) -> f̂!(x̂next, estim, model, x̂, u, d), x̂next, estim.x̂)
768+
= ForwardDiff.jacobian((ŷ, x̂) -> ĥ!(ŷ, estim, model, x̂, d), ŷ, estim.x̂)
763769
return update_estimate_kf!(estim, u, ym, d, F̂, Ĥ[estim.i_ym, :], estim.P̂, estim.x̂)
764770
end
765771

@@ -790,7 +796,7 @@ function validate_kfcov(nym, nx̂, Q̂, R̂, P̂0=nothing)
790796
end
791797

792798
"""
793-
update_estimate_kf!(estim, u, ym, d, Â, Ĉm, P̂, x̂=nothing)
799+
update_estimate_kf!(estim::StateEstimator, u, ym, d, Â, Ĉm, P̂, x̂=nothing)
794800
795801
Update time-varying/extended Kalman Filter estimates with augmented `Â` and `Ĉm` matrices.
796802
@@ -801,16 +807,22 @@ The implementation uses in-place operations and explicit factorization to reduce
801807
allocations. See e.g. [`KalmanFilter`](@ref) docstring for the equations. If `isnothing(x̂)`,
802808
only the covariance `P̂` is updated.
803809
"""
804-
function update_estimate_kf!(estim, u, ym, d, Â, Ĉm, P̂, x̂=nothing)
810+
function update_estimate_kf!(
811+
estim::StateEstimator{NT}, u, ym, d, Â, Ĉm, P̂, x̂=nothing
812+
) where NT<:Real
805813
Q̂, R̂, M̂ = estim.Q̂, estim.R̂, estim.
806814
mul!(M̂, P̂, Ĉm')
807815
rdiv!(M̂, cholesky!(Hermitian(Ĉm ** Ĉm' .+ R̂)))
808816
if !isnothing(x̂)
809817
mul!(estim.K̂, Â, M̂)
810-
ŷm = @views (estim, estim.model, x̂, d)[estim.i_ym]
818+
x̂next, ŷ = Vector{NT}(undef, estim.nx̂), Vector{NT}(undef, estim.model.ny)
819+
ĥ!(ŷ, estim, estim.model, x̂, d)
820+
ŷm = @views ŷ[estim.i_ym]
811821
= ŷm
812822
v̂ .= ym .- ŷm
813-
x̂ .= (estim, estim.model, x̂, u, d) .+ mul!(x̂, estim.K̂, v̂)
823+
f̂!(x̂next, estim, estim.model, x̂, u, d)
824+
mul!(x̂next, estim.K̂, v̂, 1, 1)
825+
estim.x̂ .= x̂next
814826
end
815827
.data .=* (P̂ .-* Ĉm * P̂) *' .+# .data is necessary for Hermitians
816828
return nothing

src/model/linmodel.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,23 +217,34 @@ disturbances ``\mathbf{d_0 = d - d_{op}}``. The Moore-Penrose pseudo-inverse com
217217
``\mathbf{(I - A)^{-1}}`` to support integrating `model` (integrator states will be 0).
218218
"""
219219
function steadystate!(model::LinModel, u, d)
220-
model.x[:] = pinv(I - model.A)*(model.Bu*(u - model.uop) + model.Bd*(d - model.dop))
220+
model.x .= pinv(I - model.A)*(model.Bu*(u - model.uop) + model.Bd*(d - model.dop))
221221
return nothing
222222
end
223223

224224
"""
225-
f(model::LinModel, x, u, d)
225+
f!(xnext, model::LinModel, x, u, d) -> xnext
226226
227-
Evaluate ``\\mathbf{A x + B_u u + B_d d}`` when `model` is a [`LinModel`](@ref).
227+
Evaluate `xnext = A*x + Bu*u + Bd*d` in-place when `model` is a [`LinModel`](@ref).
228228
"""
229-
f(model::LinModel, x, u, d) = model.A * x + model.Bu * u + model.Bd * d
229+
function f!(xnext, model::LinModel, x, u, d)
230+
xnext .= 0
231+
mul!(xnext, model.A, x, 1, 1)
232+
mul!(xnext, model.Bu, u, 1, 1)
233+
mul!(xnext, model.Bd, d, 1, 1)
234+
return xnext
235+
end
230236

231237

232238
"""
233-
h(model::LinModel, x, d)
239+
h!(y, model::LinModel, x, d) -> y
234240
235-
Evaluate ``\\mathbf{C x + D_d d}`` when `model` is a [`LinModel`](@ref).
241+
Evaluate `y = C*x + Dd*d` in-place when `model` is a [`LinModel`](@ref).
236242
"""
237-
h(model::LinModel, x, d) = model.C*x + model.Dd*d
243+
function h!(y, model::LinModel, x, d)
244+
y .= 0
245+
mul!(y, model.C, x, 1, 1)
246+
mul!(y, model.Dd, d, 1, 1)
247+
return y
248+
end
238249

239250
typestr(model::LinModel) = "linear"

src/model/nonlinmodel.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,10 @@ end
122122
steadystate!(::SimModel, _ , _ ) = nothing
123123

124124

125-
126-
127-
128-
"Call ``\\mathbf{f!(x, u, d)}`` with `model.f!` function for [`NonLinModel`](@ref)."
125+
"Call `f!(xnext, x, u, d)` with `model.f!` method for [`NonLinModel`](@ref)."
129126
f!(xnext, model::NonLinModel, x, u, d) = model.f!(xnext, x, u, d)
130127

131-
"Call ``\\mathbf{h!(y, x, d)}`` with `model.h` function for [`NonLinModel`](@ref)."
128+
"Call `h!(y, x, d)` with `model.h` method for [`NonLinModel`](@ref)."
132129
h!(y, model::NonLinModel, x, d) = model.h!(y, x, d)
133130

134131
typestr(model::NonLinModel) = "nonlinear"

src/sim_model.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ true
123123
124124
"""
125125
function initstate!(model::SimModel, u, d=empty(model.x))
126-
validate_args(model::SimModel, u, d)
126+
validate_args(model::SimModel, d, u)
127127
steadystate!(model, u, d)
128128
return model.x
129129
end
@@ -142,9 +142,11 @@ julia> x = updatestate!(model, [1])
142142
1.0
143143
```
144144
"""
145-
function updatestate!(model::SimModel, u, d=empty(model.x))
145+
function updatestate!(model::SimModel{NT}, u, d=empty(model.x)) where NT <: Real
146146
validate_args(model::SimModel, d, u)
147-
f!(model.x, model, model.x, u .- model.uop, d .- model.dop)
147+
xnext = Vector{NT}(undef, model.nx)
148+
f!(xnext, model, model.x, u - model.uop, d - model.dop)
149+
model.x .= xnext
148150
return model.x
149151
end
150152

@@ -167,7 +169,8 @@ julia> y = evaloutput(model)
167169
function evaloutput(model::SimModel{NT}, d=empty(model.x)) where NT <: Real
168170
validate_args(model, d)
169171
y = Vector{NT}(undef, model.ny)
170-
h!(y, model, model.x, d .- model.dop) .+ model.yop
172+
h!(y, model, model.x, d - model.dop)
173+
y .+= model.yop
171174
return y
172175
end
173176

0 commit comments

Comments
 (0)