Skip to content

Commit d718206

Browse files
committed
Fix ZeroSumNormal docstring: use K instead of n in covariance formula
1 parent 02235a0 commit d718206

File tree

1 file changed

+47
-21
lines changed

1 file changed

+47
-21
lines changed

pymc/distributions/multivariate.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2700,34 +2700,60 @@ class ZeroSumNormal(Distribution):
27002700
r"""
27012701
Normal distribution where one or several axes are constrained to sum to zero.
27022702
2703-
By default, the last axis is constrained to sum to zero. See the `n_zerosum_axes`
2704-
kwarg for more details.
2705-
2706-
The constrained distribution follows a multivariate Normal distribution. For the
2707-
standard 1D case with a single constrained axis of size K, the covariance is:
2703+
By default, the last axis is constrained to sum to zero.
2704+
See `n_zerosum_axes` kwarg for more details.
27082705
27092706
.. math::
27102707
2711-
ZSN(\sigma) = N\left(0, \sigma^2 \left(I_K - \tfrac{1}{K} J_K\right)\right)
2712-
2713-
where:
2714-
2715-
- :math:`I_K` is the :math:`K \times K` identity matrix,
2716-
- :math:`J_K` is the :math:`K \times K` matrix of ones,
2717-
- :math:`K` is the size of the constrained axis.
2718-
2719-
Using :math:`K` avoids confusion with ``n_zerosum_axes``, which counts how many
2720-
axes are constrained, not their length.
2708+
\begin{align*}
2709+
ZSN(\sigma) = N \Big( 0, \sigma^2 (I_K - \tfrac{1}{K}J_K) \Big) \\
2710+
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
2711+
K = \text{size (length) of the constrained axis}
2712+
\end{align*}
27212713
27222714
Parameters
27232715
----------
27242716
sigma : tensor_like of float
2725-
Scale parameter (sigma > 0). Defaults to 1.
2726-
``sigma`` cannot have length > 1 across the zero-sum axes.
2727-
n_zerosum_axes : int, defaults to 1
2728-
Number of axes along which the zero-sum constraint is enforced.
2729-
dims : sequence of strings, optional
2730-
shape : tuple of integers, optional
2717+
Scale parameter (sigma > 0).
2718+
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2719+
Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
2720+
n_zerosum_axes: int, defaults to 1
2721+
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2722+
Defaults to 1, i.e the rightmost axis.
2723+
dims: sequence of strings, optional
2724+
Dimension names of the distribution. Works the same as for other PyMC distributions.
2725+
Necessary if ``shape`` is not passed.
2726+
shape: tuple of integers, optional
2727+
Shape of the distribution. Works the same as for other PyMC distributions.
2728+
Necessary if ``dims`` or ``observed`` is not passed.
2729+
2730+
Warnings
2731+
--------
2732+
Currently, ``sigma`` cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
2733+
2734+
``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
2735+
just use ``pm.Normal``.
2736+
2737+
Examples
2738+
--------
2739+
Define a `ZeroSumNormal` variable, with `sigma=1` and
2740+
`n_zerosum_axes=1` by default::
2741+
2742+
COORDS = {
2743+
"regions": ["a", "b", "c"],
2744+
"answers": ["yes", "no", "whatever", "don't understand question"],
2745+
}
2746+
with pm.Model(coords=COORDS) as m:
2747+
# the zero sum axis will be 'answers'
2748+
v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
2749+
2750+
with pm.Model(coords=COORDS) as m:
2751+
# the zero sum axes will be 'answers' and 'regions'
2752+
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2)
2753+
2754+
with pm.Model(coords=COORDS) as m:
2755+
# the zero sum axes will be the last two
2756+
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2)
27312757
"""
27322758

27332759
rv_type = ZeroSumNormalRV

0 commit comments

Comments
 (0)