     gradient,
     hessian,
     inputvars,
+    join_nonshared_inputs,
     rewrite_pregrad,
 )
 from pymc.util import (
@@ -172,6 +173,9 @@ def __init__(
         dtype=None,
         casting="no",
         compute_grads=True,
+        model=None,
+        initial_point=None,
+        ravel_inputs: bool | None = None,
         **kwargs,
     ):
         if extra_vars_and_values is None:
@@ -219,9 +223,7 @@ def __init__(
         givens = []
         self._extra_vars_shared = {}
         for var, value in extra_vars_and_values.items():
-            shared = pytensor.shared(
-                value, var.name + "_shared__", shape=[1 if s == 1 else None for s in value.shape]
-            )
+            shared = pytensor.shared(value, var.name + "_shared__", shape=value.shape)
             self._extra_vars_shared[var.name] = shared
             givens.append((var, shared))
 
@@ -231,13 +233,28 @@ def __init__(
             grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
             for grad_wrt, var in zip(grads, grad_vars):
                 grad_wrt.name = f"{var.name}_grad"
-            outputs = [cost, *grads]
+            grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads])
+            outputs = [cost, grads]
         else:
             outputs = [cost]
 
-        inputs = grad_vars
+        if ravel_inputs:
+            if initial_point is None:
+                initial_point = modelcontext(model).initial_point()
+            outputs, raveled_grad_vars = join_nonshared_inputs(
+                point=initial_point, inputs=grad_vars, outputs=outputs, make_inputs_shared=False
+            )
+            inputs = [raveled_grad_vars]
+        else:
+            if ravel_inputs is None:
+                warnings.warn(
+                    "ValueGradFunction will become a function of raveled inputs.\n"
+                    "Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release."
+                )
+            inputs = grad_vars
 
         self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
+        self._raveled_inputs = ravel_inputs
 
     def set_weights(self, values):
         if values.shape != (self._n_costs - 1,):
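
Note (not part of the patch): the `pt.join` line above concatenates the per-variable gradients into one flat vector, so the compiled function returns a single gradient array instead of one array per variable. A minimal sketch, with hypothetical symbolic variables `a` and `b` standing in for two gradients of different shapes:

```python
import numpy as np
import pytensor.tensor as pt

a = pt.matrix("a")  # stands in for a (2, 3)-shaped gradient
b = pt.vector("b")  # stands in for a (4,)-shaped gradient

# Ravel each piece to 1-D and join along axis 0, as in the diff above.
flat = pt.join(0, *[pt.atleast_1d(g.ravel()) for g in (a, b)])
print(flat.eval({a: np.ones((2, 3)), b: np.zeros(4)}).shape)  # (10,)
```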
@@ -247,38 +264,29 @@ def set_weights(self, values):
     def set_extra_values(self, extra_vars):
         self._extra_are_set = True
         for var in self._extra_vars:
-            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name], borrow=True)
 
     def get_extra_values(self):
         if not self._extra_are_set:
             raise ValueError("Extra values are not set.")
 
         return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}
 
-    def __call__(self, grad_vars, grad_out=None, extra_vars=None):
+    def __call__(self, grad_vars, *, extra_vars=None):
         if extra_vars is not None:
             self.set_extra_values(extra_vars)
-
-        if not self._extra_are_set:
+        elif not self._extra_are_set:
             raise ValueError("Extra values are not set.")
 
         if isinstance(grad_vars, RaveledVars):
-            grad_vars = list(DictToArrayBijection.rmap(grad_vars).values())
-
-        cost, *grads = self._pytensor_function(*grad_vars)
-
-        if grads:
-            grads_raveled = DictToArrayBijection.map(
-                {v.name: gv for v, gv in zip(self._grad_vars, grads)}
-            )
-
-            if grad_out is None:
-                return cost, grads_raveled.data
+            if self._raveled_inputs:
+                grad_vars = (grad_vars.data,)
             else:
-                np.copyto(grad_out, grads_raveled.data)
-                return cost
-        else:
-            return cost
+                grad_vars = DictToArrayBijection.rmap(grad_vars).values()
+        elif self._raveled_inputs and not isinstance(grad_vars, Sequence):
+            grad_vars = (grad_vars,)
+
+        return self._pytensor_function(*grad_vars)
 
     @property
     def profile(self):
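
Note (not part of the patch): with the new `__call__`, a `RaveledVars` is unwrapped to its flat `.data` when the function was compiled with raveled inputs, and the compiled outputs are returned as-is rather than being re-raveled by the wrapper. A hedged usage sketch with a throwaway model `m`, assuming it has no extra (non-differentiated) value variables so an empty `set_extra_values({})` call suffices, and assuming `DictToArrayBijection.map` preserves the value-variable order used at compile time:

```python
import pymc as pm
from pymc.blocking import DictToArrayBijection

with pm.Model() as m:
    x = pm.Normal("x", shape=2)
    y = pm.Normal("y")

fn = m.logp_dlogp_function(ravel_inputs=True)
fn.set_extra_values({})  # no extra (non-gradient) value variables in this toy model

# Build a RaveledVars from a point dict; __call__ passes its flat .data through.
q = DictToArrayBijection.map(m.initial_point())
logp, dlogp = fn(q)  # joint logp plus one concatenated gradient vector
```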
@@ -521,7 +529,14 @@ def root(self):
     def isroot(self):
         return self.parent is None
 
-    def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
+    def logp_dlogp_function(
+        self,
+        grad_vars=None,
+        tempered=False,
+        initial_point=None,
+        ravel_inputs: bool | None = None,
+        **kwargs,
+    ):
         """Compile a PyTensor function that computes logp and gradient.
 
         Parameters
@@ -547,13 +562,22 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
             costs = [self.logp()]
 
         input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
-        ip = self.initial_point(0)
+        if initial_point is None:
+            initial_point = self.initial_point(0)
         extra_vars_and_values = {
-            var: ip[var.name]
+            var: initial_point[var.name]
             for var in self.value_vars
             if var in input_vars and var not in grad_vars
         }
-        return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
+        return ValueGradFunction(
+            costs,
+            grad_vars,
+            extra_vars_and_values,
+            model=self,
+            initial_point=initial_point,
+            ravel_inputs=ravel_inputs,
+            **kwargs,
+        )
 
     def compile_logp(
         self,
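
Note (not part of the patch): an end-to-end sketch of the new `ravel_inputs` flag on `Model.logp_dlogp_function`, using a throwaway model. Passing `ravel_inputs` explicitly avoids the warning added above; with `ravel_inputs=True` the compiled function accepts one flat NumPy array covering all value variables:

```python
import numpy as np
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", shape=3)

fn = m.logp_dlogp_function(ravel_inputs=True)
fn.set_extra_values({})  # nothing extra to set in this toy model

# One flat array in, joint logp and a single concatenated gradient out.
logp, dlogp = fn(np.zeros(3))
assert dlogp.shape == (3,)
```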