@@ -147,32 +147,45 @@ class BaseTrace(IBaseTrace):
147147 use different test point that might be with changed variables shapes
148148 """
149149
150- def __init__ (self , name , model = None , vars = None , test_point = None ):
151- self .name = name
152-
150+ def __init__ (
151+ self ,
152+ name = None ,
153+ model = None ,
154+ vars = None ,
155+ test_point = None ,
156+ * ,
157+ fn = None ,
158+ var_shapes = None ,
159+ var_dtypes = None ,
160+ ):
153161 model = modelcontext (model )
154- self . model = model
162+
155163 if vars is None :
156164 vars = model .unobserved_value_vars
157165
158166 unnamed_vars = {var for var in vars if var .name is None }
159167 if unnamed_vars :
160168 raise Exception (f"Can't trace unnamed variables: { unnamed_vars } " )
161- self . vars = vars
162- self . varnames = [ var . name for var in vars ]
163- self . fn = model .compile_fn (vars , inputs = model .value_vars , on_unused_input = "ignore" )
169+
170+ if fn is None :
171+ fn = model .compile_fn (vars , inputs = model .value_vars , on_unused_input = "ignore" )
164172
165173 # Get variable shapes. Most backends will need this
166174 # information.
167- if test_point is None :
168- test_point = model .initial_point ()
169- else :
170- test_point_ = model .initial_point ().copy ()
171- test_point_ .update (test_point )
172- test_point = test_point_
173- var_values = list (zip (self .varnames , self .fn (test_point )))
174- self .var_shapes = {var : value .shape for var , value in var_values }
175- self .var_dtypes = {var : value .dtype for var , value in var_values }
175+ if var_shapes is None or var_dtypes is None :
176+ if test_point is None :
177+ test_point = model .initial_point ()
178+ var_values = tuple (zip (vars , fn (** test_point )))
179+ var_shapes = {var .name : value .shape for var , value in var_values }
180+ var_dtypes = {var .name : value .dtype for var , value in var_values }
181+
182+ self .name = name
183+ self .model = model
184+ self .fn = fn
185+ self .vars = vars
186+ self .varnames = [var .name for var in vars ]
187+ self .var_shapes = var_shapes
188+ self .var_dtypes = var_dtypes
176189 self .chain = None
177190 self ._is_base_setup = False
178191 self .sampler_vars = None
0 commit comments