@@ -1747,7 +1747,7 @@ def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]:
17471747 )
17481748 return {name : tuple (shape ) for name , shape in zip (names , f ())}
17491749
1750- def check_start_vals (self , start ):
1750+ def check_start_vals (self , start , ** kwargs ):
17511751 r"""Check that the starting values for MCMC do not cause the relevant log probability
17521752 to evaluate to something invalid (e.g. Inf or NaN)
17531753
@@ -1758,6 +1758,8 @@ def check_start_vals(self, start):
17581758 Defaults to ``trace.point(-1))`` if there is a trace provided and
17591759 ``model.initial_point`` if not (defaults to empty dict). Initialization
17601760 methods for NUTS (see ``init`` keyword) can overwrite the default.
1761+ Other keyword arguments :
1762+ Any other keyword argument is sent to :py:meth:`~pymc.model.core.Model.point_logps`.
17611763
17621764 Raises
17631765 ------
@@ -1787,7 +1789,7 @@ def check_start_vals(self, start):
17871789 f"Valid keys are: { valid_keys } , but { extra_keys } was supplied"
17881790 )
17891791
1790- initial_eval = self .point_logps (point = elem )
1792+ initial_eval = self .point_logps (point = elem , ** kwargs )
17911793
17921794 if not all (np .isfinite (v ) for v in initial_eval .values ()):
17931795 raise SamplingError (
@@ -1797,7 +1799,7 @@ def check_start_vals(self, start):
17971799 "You can call `model.debug()` for more details."
17981800 )
17991801
1800- def point_logps (self , point = None , round_vals = 2 ):
1802+ def point_logps (self , point = None , round_vals = 2 , ** kwargs ):
18011803 """Computes the log probability of `point` for all random variables in the model.
18021804
18031805 Parameters
@@ -1807,6 +1809,8 @@ def point_logps(self, point=None, round_vals=2):
18071809 is used.
18081810 round_vals : int, default 2
18091811 Number of decimals to round log-probabilities.
1812+ Other keyword arguments :
1813+ Any other keyword argument are sent provided to :py:meth:`~pymc.model.core.Model.compile_fn`
18101814
18111815 Returns
18121816 -------
@@ -1822,7 +1826,7 @@ def point_logps(self, point=None, round_vals=2):
18221826 factor .name : np .round (np .asarray (factor_logp ), round_vals )
18231827 for factor , factor_logp in zip (
18241828 factors ,
1825- self .compile_fn (factor_logps_fn )(point ),
1829+ self .compile_fn (factor_logps_fn , ** kwargs )(point ),
18261830 )
18271831 }
18281832
0 commit comments