@@ -2176,7 +2176,7 @@ For more details on this argument, see the ODEFunction documentation.
21762176The fields of the ControlFunction type directly match the names of the inputs.
21772177"""
21782178struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2179- JP, CJP, SP, TPJ, O, TCV, CTCV,
2179+ JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, CTCV,
21802180 SYS, ID} <: AbstractControlFunction{iip}
21812181 f:: F
21822182 mass_matrix:: TMM
@@ -2189,10 +2189,12 @@ struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
21892189 jac_prototype:: JP
21902190 controljac_prototype:: CJP
21912191 sparsity:: SP
2192+ Wfact:: TW
2193+ Wfact_t:: TWt
2194+ W_prototype:: WP
21922195 paramjac:: TPJ
21932196 observed:: O
21942197 colorvec:: TCV
2195- controlcolorvec:: CTCV
21962198 sys:: SYS
21972199 initialization_data:: ID
21982200end
@@ -4698,6 +4700,146 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
46984700 BatchIntegralFunction {calculated_iip} (f, integrand_prototype; kwargs... )
46994701end
47004702
4703+ function ControlFunction {iip, specialize} (f;
4704+ mass_matrix = __has_mass_matrix (f) ? f. mass_matrix :
4705+ I,
4706+ analytic = __has_analytic (f) ? f. analytic : nothing ,
4707+ tgrad = __has_tgrad (f) ? f. tgrad : nothing ,
4708+ jac = __has_jac (f) ? f. jac : nothing ,
4709+ controljac = __has_controljac (f) ? f. controljac : nothing ,
4710+ jvp = __has_jvp (f) ? f. jvp : nothing ,
4711+ vjp = __has_vjp (f) ? f. vjp : nothing ,
4712+ jac_prototype = __has_jac_prototype (f) ?
4713+ f. jac_prototype :
4714+ nothing ,
4715+ controljac_prototype = __has_controljac_prototype (f) ?
4716+ f. controljac_prototype :
4717+ nothing ,
4718+ sparsity = __has_sparsity (f) ? f. sparsity :
4719+ jac_prototype,
4720+ Wfact = __has_Wfact (f) ? f. Wfact : nothing ,
4721+ Wfact_t = __has_Wfact_t (f) ? f. Wfact_t : nothing ,
4722+ W_prototype = __has_W_prototype (f) ? f. W_prototype : nothing ,
4723+ paramjac = __has_paramjac (f) ? f. paramjac : nothing ,
4724+ observed = __has_observed (f) ? f. observed :
4725+ DEFAULT_OBSERVED,
4726+ colorvec = __has_colorvec (f) ? f. colorvec : nothing ,
4727+ sys = __has_sys (f) ? f. sys : nothing ,
4728+ initializeprob = __has_initializeprob (f) ? f. initializeprob : nothing ,
4729+ update_initializeprob! = __has_update_initializeprob! (f) ?
4730+ f. update_initializeprob! : nothing ,
4731+ initializeprobmap = __has_initializeprobmap (f) ? f. initializeprobmap : nothing ,
4732+ initializeprobpmap = __has_initializeprobpmap (f) ? f. initializeprobpmap : nothing ,
4733+ initialization_data = __has_initialization_data (f) ? f. initialization_data :
4734+ nothing ,
4735+ nlprob_data = __has_nlprob_data (f) ? f. nlprob_data : nothing
4736+ ) where {iip,
4737+ specialize
4738+ }
4739+ if mass_matrix === I && f isa Tuple
4740+ mass_matrix = ((I for i in 1 : length (f)). .. ,)
4741+ end
4742+
4743+ if (specialize === FunctionWrapperSpecialize) &&
4744+ ! (f isa FunctionWrappersWrappers. FunctionWrappersWrapper)
4745+ error (" FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!" )
4746+ end
4747+
4748+ if jac === nothing && isa (jac_prototype, AbstractSciMLOperator)
4749+ if iip
4750+ jac = update_coefficients! # (J,u,p,t)
4751+ else
4752+ jac = (u, p, t) -> update_coefficients (deepcopy (jac_prototype), u, p, t)
4753+ end
4754+ end
4755+
4756+ if controljac === nothing && isa (controljac_prototype, AbstractSciMLOperator)
4757+ if iip_bc
4758+ controljac = update_coefficients! # (J,u,p,t)
4759+ else
4760+ controljac = (u, p, t) -> update_coefficients! (deepcopy (controljac_prototype), u, p, t)
4761+ end
4762+ end
4763+
4764+ if jac_prototype != = nothing && colorvec === nothing &&
4765+ ArrayInterface. fast_matrix_colors (jac_prototype)
4766+ _colorvec = ArrayInterface. matrix_colors (jac_prototype)
4767+ else
4768+ _colorvec = colorvec
4769+ end
4770+
4771+ jaciip = jac != = nothing ? isinplace (jac, 4 , " jac" , iip) : iip
4772+ controljaciip = controljac != = nothing ? isinplace (controljac, 4 , " controljac" , iip) : iip
4773+ tgradiip = tgrad != = nothing ? isinplace (tgrad, 4 , " tgrad" , iip) : iip
4774+ jvpiip = jvp != = nothing ? isinplace (jvp, 5 , " jvp" , iip) : iip
4775+ vjpiip = vjp != = nothing ? isinplace (vjp, 5 , " vjp" , iip) : iip
4776+ Wfactiip = Wfact != = nothing ? isinplace (Wfact, 5 , " Wfact" , iip) : iip
4777+ Wfact_tiip = Wfact_t != = nothing ? isinplace (Wfact_t, 5 , " Wfact_t" , iip) : iip
4778+ paramjaciip = paramjac != = nothing ? isinplace (paramjac, 4 , " paramjac" , iip) : iip
4779+
4780+ nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4781+ paramjaciip) .!= iip
4782+ if any (nonconforming)
4783+ nonconforming = findall (nonconforming)
4784+ functions = [" jac" , " tgrad" , " jvp" , " vjp" , " Wfact" , " Wfact_t" , " paramjac" ][nonconforming]
4785+ throw (NonconformingFunctionsError (functions))
4786+ end
4787+
4788+ _f = prepare_function (f)
4789+
4790+ sys = sys_or_symbolcache (sys, syms, paramsyms, indepsym)
4791+ initdata = reconstruct_initialization_data (
4792+ initialization_data, initializeprob, update_initializeprob!,
4793+ initializeprobmap, initializeprobpmap)
4794+
4795+ if specialize === NoSpecialize
4796+ ControlFunction{iip, specialize,
4797+ Any, Any, Any, Any,
4798+ Any, Any, Any, Any, typeof (jac_prototype), typeof (controljac_prototype),
4799+ typeof (sparsity), Any, Any, typeof (W_prototype), Any,
4800+ Any,
4801+ typeof (_colorvec),
4802+ typeof (sys), Union{Nothing, OverrideInitData}}(
4803+ _f, mass_matrix, analytic, tgrad, jac, controljac,
4804+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4805+ Wfact_t, W_prototype, paramjac,
4806+ observed, _colorvec, sys, initdata)
4807+ elseif specialize === false
4808+ ControlFunction{iip, FunctionWrapperSpecialize,
4809+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4810+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4811+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4812+ typeof (paramjac),
4813+ typeof (observed),
4814+ typeof (_colorvec),
4815+ typeof (sys), typeof (initdata)}(_f, mass_matrix,
4816+ analytic, tgrad, jac, controljac,
4817+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4818+ Wfact_t, W_prototype, paramjac,
4819+ observed, _colorvec, sys, initdata)
4820+ else
4821+ ControlFunction{iip, specialize,
4822+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4823+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4824+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4825+ typeof (paramjac),
4826+ typeof (observed),
4827+ typeof (_colorvec),
4828+ typeof (sys), typeof (initdata)}(
4829+ _f, mass_matrix, analytic, tgrad,
4830+ jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4831+ Wfact_t, W_prototype, paramjac,
4832+ observed, _colorvec, sys, initdata)
4833+ end
4834+ end
4835+
4836+ function ODEFunction {iip} (f; kwargs... ) where {iip}
4837+ ODEFunction {iip, FullSpecialize} (f; kwargs... )
4838+ end
4839+ ODEFunction {iip} (f:: ODEFunction ; kwargs... ) where {iip} = f
4840+ ODEFunction (f; kwargs... ) = ODEFunction {isinplace(f, 4), FullSpecialize} (f; kwargs... )
4841+ ODEFunction (f:: ODEFunction ; kwargs... ) = f
4842+
47014843# ######### Utility functions
47024844
47034845function sys_or_symbolcache (sys, syms, paramsyms, indepsym = nothing )
@@ -4731,6 +4873,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
47314873__has_W_prototype (f) = isdefined (f, :W_prototype )
47324874__has_paramjac (f) = isdefined (f, :paramjac )
47334875__has_jac_prototype (f) = isdefined (f, :jac_prototype )
4876+ __has_controljac_prototype (f) = isdefined (f, :controljac_prototype )
47344877__has_sparsity (f) = isdefined (f, :sparsity )
47354878__has_mass_matrix (f) = isdefined (f, :mass_matrix )
47364879__has_syms (f) = isdefined (f, :syms )
0 commit comments