3333from pytensor .graph .utils import MetaType
3434from pytensor .scan .op import Scan
3535from pytensor .tensor .basic import as_tensor_variable
36+ from pytensor .tensor .blockwise import safe_signature
3637from pytensor .tensor .random .op import RandomVariable
3738from pytensor .tensor .random .rewriting import local_subtensor_rv_lift
3839from pytensor .tensor .random .type import RandomGeneratorType , RandomType
3940from pytensor .tensor .random .utils import normalize_size_param
4041from pytensor .tensor .rewriting .shape import ShapeFeature
42+ from pytensor .tensor .utils import _parse_gufunc_signature
4143from pytensor .tensor .variable import TensorVariable
4244from typing_extensions import TypeAlias
4345
@@ -261,6 +263,12 @@ class SymbolicRandomVariable(OpFromGraph):
261263 (0 for scalar, 1 for vector, ...)
262264 """
263265
266+ ndims_params : Optional [Sequence [int ]] = None
267+ """Number of core dimensions of the distribution's parameters."""
268+
269+ signature : str = None
270+ """Numpy-like vectorized signature of the distribution."""
271+
264272 inline_logprob : bool = False
265273 """Specifies whether the logprob function is derived automatically by introspection
266274 of the inner graph.
@@ -271,9 +279,25 @@ class SymbolicRandomVariable(OpFromGraph):
271279 _print_name : tuple [str , str ] = ("Unknown" , "\\ operatorname{Unknown}" )
272280 """Tuple of (name, latex name) used for for pretty-printing variables of this type"""
273281
274- def __init__ (self , * args , ndim_supp , ** kwargs ):
275- """Initialitze a SymbolicRandomVariable class."""
276- self .ndim_supp = ndim_supp
282+ def __init__ (
283+ self ,
284+ * args ,
285+ ** kwargs ,
286+ ):
287+ """Initialize a SymbolicRandomVariable class."""
288+ if self .signature is None :
289+ self .signature = kwargs .get ("signature" , None )
290+
291+ if self .signature is not None :
292+ inputs_sig , outputs_sig = _parse_gufunc_signature (self .signature )
293+ self .ndims_params = [len (sig ) for sig in inputs_sig ]
294+ self .ndim_supp = max (len (out_sig ) for out_sig in outputs_sig )
295+
296+ if self .ndim_supp is None :
297+ self .ndim_supp = kwargs .get ("ndim_supp" , None )
298+ if self .ndim_supp is None :
299+ raise ValueError ("ndim_supp or gufunc_signature must be provided" )
300+
277301 kwargs .setdefault ("inline" , True )
278302 super ().__init__ (* args , ** kwargs )
279303
@@ -286,6 +310,11 @@ def update(self, node: Node):
286310 """
287311 return {}
288312
313+ def batch_ndim (self , node : Node ) -> int :
314+ """Number of dimensions of the distribution's batch shape."""
315+ out_ndim = max (getattr (out .type , "ndim" , 0 ) for out in node .outputs )
316+ return out_ndim - self .ndim_supp
317+
289318
290319class Distribution (metaclass = DistributionMeta ):
291320 """Statistical distribution"""
@@ -558,23 +587,29 @@ def dist(
558587 logcdf : Optional [Callable ] = None ,
559588 random : Optional [Callable ] = None ,
560589 support_point : Optional [Callable ] = None ,
561- ndim_supp : int = 0 ,
590+ ndim_supp : Optional [ int ] = None ,
562591 ndims_params : Optional [Sequence [int ]] = None ,
592+ signature : Optional [str ] = None ,
563593 dtype : str = "floatX" ,
564594 class_name : str = "CustomDist" ,
565595 ** kwargs ,
566596 ):
597+ if ndim_supp is None or ndims_params is None :
598+ if signature is None :
599+ ndim_supp = 0
600+ ndims_params = [0 ] * len (dist_params )
601+ else :
602+ inputs , outputs = _parse_gufunc_signature (signature )
603+ ndim_supp = max (len (out ) for out in outputs )
604+ ndims_params = [len (inp ) for inp in inputs ]
605+
567606 if ndim_supp > 0 :
568607 raise NotImplementedError (
569608 "CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
570609 )
571610
572611 dist_params = [as_tensor_variable (param ) for param in dist_params ]
573612
574- # Assume scalar ndims_params
575- if ndims_params is None :
576- ndims_params = [0 ] * len (dist_params )
577-
578613 if logp is None :
579614 logp = default_not_implemented (class_name , "logp" )
580615
@@ -614,7 +649,7 @@ def rv_op(
614649 random : Optional [Callable ],
615650 support_point : Optional [Callable ],
616651 ndim_supp : int ,
617- ndims_params : Optional [ Sequence [int ] ],
652+ ndims_params : Sequence [int ],
618653 dtype : str ,
619654 class_name : str ,
620655 ** kwargs ,
@@ -702,7 +737,9 @@ def dist(
702737 logp : Optional [Callable ] = None ,
703738 logcdf : Optional [Callable ] = None ,
704739 support_point : Optional [Callable ] = None ,
705- ndim_supp : int = 0 ,
740+ ndim_supp : Optional [int ] = None ,
741+ ndims_params : Optional [Sequence [int ]] = None ,
742+ signature : Optional [str ] = None ,
706743 dtype : str = "floatX" ,
707744 class_name : str = "CustomDist" ,
708745 ** kwargs ,
@@ -712,14 +749,24 @@ def dist(
712749 if logcdf is None :
713750 logcdf = default_not_implemented (class_name , "logcdf" )
714751
752+ if signature is None :
753+ if ndim_supp is None :
754+ ndim_supp = 0
755+ if ndims_params is None :
756+ ndims_params = [0 ] * len (dist_params )
757+ signature = safe_signature (
758+ core_inputs = [pt .tensor (shape = (None ,) * ndim_param ) for ndim_param in ndims_params ],
759+ core_outputs = [pt .tensor (shape = (None ,) * ndim_supp )],
760+ )
761+
715762 return super ().dist (
716763 dist_params ,
717764 class_name = class_name ,
718765 logp = logp ,
719766 logcdf = logcdf ,
720767 dist = dist ,
721768 support_point = support_point ,
722- ndim_supp = ndim_supp ,
769+ signature = signature ,
723770 ** kwargs ,
724771 )
725772
@@ -732,7 +779,7 @@ def rv_op(
732779 logcdf : Optional [Callable ],
733780 support_point : Optional [Callable ],
734781 size = None ,
735- ndim_supp : int ,
782+ signature : str ,
736783 class_name : str ,
737784 ):
738785 size = normalize_size_param (size )
@@ -745,6 +792,10 @@ def rv_op(
745792 dummy_params = [dummy_size_param , * dummy_dist_params ]
746793 dummy_updates_dict = collect_default_updates (inputs = dummy_params , outputs = (dummy_rv ,))
747794
795+ signature = cls ._infer_final_signature (
796+ signature , len (dummy_params ), len (dummy_updates_dict )
797+ )
798+
748799 rv_type = type (
749800 class_name ,
750801 (CustomSymbolicDistRV ,),
@@ -802,7 +853,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
802853 new_rv_op = rv_type (
803854 inputs = dummy_params ,
804855 outputs = [* dummy_updates_dict .values (), dummy_rv ],
805- ndim_supp = ndim_supp ,
856+ signature = signature ,
806857 )
807858 new_rv = new_rv_op (new_size , * dist_params )
808859
@@ -811,10 +862,30 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
811862 rv_op = rv_type (
812863 inputs = dummy_params ,
813864 outputs = [* dummy_updates_dict .values (), dummy_rv ],
814- ndim_supp = ndim_supp ,
865+ signature = signature ,
815866 )
816867 return rv_op (size , * dist_params )
817868
869+ @staticmethod
870+ def _infer_final_signature (signature : str , n_inputs , n_updates ) -> str :
871+ """Add size and updates to user provided gufunc signature if they are missing."""
872+ input_sig , output_sig = signature .split ("->" )
873+ # Numpy parser does not accept (constant) functions without inputs like "->()"
874+ # We work around as this makes sense for distributions like Flat that have no inputs
875+ if input_sig .strip () == "" :
876+ inputs = ()
877+ _ , outputs = _parse_gufunc_signature ("()" + signature )
878+ else :
879+ inputs , outputs = _parse_gufunc_signature (signature )
880+ if len (inputs ) == n_inputs - 1 :
881+ # Assume size is missing
882+ input_sig = ("()," if input_sig else "()" ) + input_sig
883+ if len (outputs ) == 1 :
884+ # Assume updates are missing
885+ output_sig = "()," * n_updates + output_sig
886+ signature = "->" .join ((input_sig , output_sig ))
887+ return signature
888+
818889
819890class CustomDist :
820891 """A helper class to create custom distributions
@@ -828,12 +899,12 @@ class CustomDist:
828899 when not provided by the user.
829900
830901 Alternatively, a user can provide a `random` function that returns numerical
831- draws (e.g., via NumPy routines), and a `logp` function that must return an
832- Python graph that represents the logp graph when evaluated. This is used for
902+ draws (e.g., via NumPy routines), and a `logp` function that must return a
903+ PyTensor graph that represents the logp graph when evaluated. This is used for
833904 mcmc sampling.
834905
835906 Additionally, a user can provide a `logcdf` and `support_point` functions that must return
836- an PyTensor graph that computes those quantities. These may be used by other PyMC
907+ PyTensor graphs that computes those quantities. These may be used by other PyMC
837908 routines.
838909
839910 Parameters
@@ -894,14 +965,18 @@ class CustomDist:
894965 distribution parameters, in the same order as they were supplied when the
895966 CustomDist was created. If ``None``, a default ``support_point`` function will be
896967 assigned that will always return 0, or an array of zeros.
897- ndim_supp : int
898- The number of dimensions in the support of the distribution. Defaults to assuming
899- a scalar distribution, i.e. ``ndim_supp = 0``.
968+ ndim_supp : Optional[int]
969+ The number of dimensions in the support of the distribution.
970+ Inferred from signature, if provided. Defaults to assuming
971+ a scalar distribution, i.e. ``ndim_supp = 0``
900972 ndims_params : Optional[Sequence[int]]
901973 The list of number of dimensions in the support of each of the distribution's
902- parameters. If ``None``, it is assumed that all parameters are scalars, hence
903- the number of dimensions of their support will be 0. This is not needed if an
904- PyTensor dist function is provided.
974+ parameters. Inferred from signature, if provided. Defaults to assuming
975+ all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
976+ signature : Optional[str]
977+ A numpy vectorize-like signature that indicates the number and core dimensionality
978+ of the input parameters and sample outputs of the CustomDist.
979+ When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
905980 dtype : str
906981 The dtype of the distribution. All draws and observations passed into the
907982 distribution will be cast onto this dtype. This is not needed if an PyTensor
@@ -939,6 +1014,7 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable:
9391014
9401015 Provide a random function that return numerical draws. This allows one to use a
9411016 CustomDist in prior and posterior predictive sampling.
1017+ A gufunc signature was also provided, which may be used by other routines.
9421018
9431019 .. code-block:: python
9441020
@@ -965,6 +1041,7 @@ def random(
9651041 mu,
9661042 logp=logp,
9671043 random=random,
1044+ signature="()->()",
9681045 observed=np.random.randn(100, 3),
9691046 size=(100, 3),
9701047 )
@@ -973,6 +1050,7 @@ def random(
9731050 Provide a dist function that creates a PyTensor graph built from other
9741051 PyMC distributions. PyMC can automatically infer that the logp of this
9751052 variable corresponds to a shifted Exponential distribution.
1053+ A gufunc signature was also provided, which may be used by other routines.
9761054
9771055 .. code-block:: python
9781056
@@ -994,6 +1072,7 @@ def dist(
9941072 lam,
9951073 shift,
9961074 dist=dist,
1075+ signature="(),()->()",
9971076 observed=[-1, -1, 0],
9981077 )
9991078
@@ -1040,10 +1119,11 @@ def __new__(
10401119 random : Optional [Callable ] = None ,
10411120 logp : Optional [Callable ] = None ,
10421121 logcdf : Optional [Callable ] = None ,
1043- moment : Optional [Callable ] = None ,
10441122 support_point : Optional [Callable ] = None ,
1045- ndim_supp : int = 0 ,
1123+ # TODO: Deprecate ndim_supp / ndims_params in favor of signature?
1124+ ndim_supp : Optional [int ] = None ,
10461125 ndims_params : Optional [Sequence [int ]] = None ,
1126+ signature : Optional [str ] = None ,
10471127 dtype : str = "floatX" ,
10481128 ** kwargs ,
10491129 ):
@@ -1057,6 +1137,7 @@ def __new__(
10571137 )
10581138 dist_params = cls .parse_dist_params (dist_params )
10591139 cls .check_valid_dist_random (dist , random , dist_params )
1140+ moment = kwargs .pop ("moment" , None )
10601141 if moment is not None :
10611142 warnings .warn (
10621143 "`moment` argument is deprecated. Use `support_point` instead." ,
@@ -1073,6 +1154,8 @@ def __new__(
10731154 logcdf = logcdf ,
10741155 support_point = support_point ,
10751156 ndim_supp = ndim_supp ,
1157+ ndims_params = ndims_params ,
1158+ signature = signature ,
10761159 ** kwargs ,
10771160 )
10781161 else :
@@ -1086,6 +1169,7 @@ def __new__(
10861169 support_point = support_point ,
10871170 ndim_supp = ndim_supp ,
10881171 ndims_params = ndims_params ,
1172+ signature = signature ,
10891173 dtype = dtype ,
10901174 ** kwargs ,
10911175 )
@@ -1099,8 +1183,9 @@ def dist(
10991183 logp : Optional [Callable ] = None ,
11001184 logcdf : Optional [Callable ] = None ,
11011185 support_point : Optional [Callable ] = None ,
1102- ndim_supp : int = 0 ,
1186+ ndim_supp : Optional [ int ] = None ,
11031187 ndims_params : Optional [Sequence [int ]] = None ,
1188+ signature : Optional [str ] = None ,
11041189 dtype : str = "floatX" ,
11051190 ** kwargs ,
11061191 ):
@@ -1114,6 +1199,8 @@ def dist(
11141199 logcdf = logcdf ,
11151200 support_point = support_point ,
11161201 ndim_supp = ndim_supp ,
1202+ ndims_params = ndims_params ,
1203+ signature = signature ,
11171204 ** kwargs ,
11181205 )
11191206 else :
@@ -1125,6 +1212,7 @@ def dist(
11251212 support_point = support_point ,
11261213 ndim_supp = ndim_supp ,
11271214 ndims_params = ndims_params ,
1215+ signature = signature ,
11281216 dtype = dtype ,
11291217 ** kwargs ,
11301218 )
0 commit comments