@@ -2700,34 +2700,60 @@ class ZeroSumNormal(Distribution):
27002700 r"""
27012701 Normal distribution where one or several axes are constrained to sum to zero.
27022702
2703- By default, the last axis is constrained to sum to zero. See the `n_zerosum_axes`
2704- kwarg for more details.
2705-
2706- The constrained distribution follows a multivariate Normal distribution. For the
2707- standard 1D case with a single constrained axis of size K, the covariance is:
2703+ By default, the last axis is constrained to sum to zero.
2704+ See `n_zerosum_axes` kwarg for more details.
27082705
27092706 .. math::
27102707
2711- ZSN(\sigma) = N\left(0, \sigma^2 \left(I_K - \tfrac{1}{K} J_K\right)\right)
2712-
2713- where:
2714-
2715- - :math:`I_K` is the :math:`K \times K` identity matrix,
2716- - :math:`J_K` is the :math:`K \times K` matrix of ones,
2717- - :math:`K` is the size of the constrained axis.
2718-
2719- Using :math:`K` avoids confusion with ``n_zerosum_axes``, which counts how many
2720- axes are constrained, not their length.
2708+ \begin{align*}
2709+ ZSN(\sigma) = N \Big( 0, \sigma^2 (I_K - \tfrac{1}{K}J_K) \Big) \\
2710+ \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
2711+ K = \text{size (length) of the constrained axis}
2712+ \end{align*}
27212713
27222714 Parameters
27232715 ----------
27242716 sigma : tensor_like of float
2725- Scale parameter (sigma > 0). Defaults to 1.
2726- ``sigma`` cannot have length > 1 across the zero-sum axes.
2727- n_zerosum_axes : int, defaults to 1
2728- Number of axes along which the zero-sum constraint is enforced.
2729- dims : sequence of strings, optional
2730- shape : tuple of integers, optional
2717+ Scale parameter (sigma > 0).
2718+ It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2719+ Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
2720+ n_zerosum_axes: int, defaults to 1
2721+ Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2722+ Defaults to 1, i.e the rightmost axis.
2723+ dims: sequence of strings, optional
2724+ Dimension names of the distribution. Works the same as for other PyMC distributions.
2725+ Necessary if ``shape`` is not passed.
2726+ shape: tuple of integers, optional
2727+ Shape of the distribution. Works the same as for other PyMC distributions.
2728+ Necessary if ``dims`` or ``observed`` is not passed.
2729+
2730+ Warnings
2731+ --------
2732+ Currently, ``sigma`` cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
2733+
2734+ ``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
2735+ just use ``pm.Normal``.
2736+
2737+ Examples
2738+ --------
2739+ Define a `ZeroSumNormal` variable, with `sigma=1` and
2740+ `n_zerosum_axes=1` by default::
2741+
2742+ COORDS = {
2743+ "regions": ["a", "b", "c"],
2744+ "answers": ["yes", "no", "whatever", "don't understand question"],
2745+ }
2746+ with pm.Model(coords=COORDS) as m:
2747+ # the zero sum axis will be 'answers'
2748+ v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
2749+
2750+ with pm.Model(coords=COORDS) as m:
2751+ # the zero sum axes will be 'answers' and 'regions'
2752+ v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2)
2753+
2754+ with pm.Model(coords=COORDS) as m:
2755+ # the zero sum axes will be the last two
2756+ v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2)
27312757 """
27322758
27332759 rv_type = ZeroSumNormalRV
0 commit comments