Skip to content

Commit 3edaae2

Browse files
committed
fix: apply numerical stability hotfix for censored Normal distributions
Import censored_distribution_stability_hotfix when PyMC is available. This fixes numerical instability in pm.Censored for Normal distributions when values are in the extreme tails. Addresses: brendanjmeade#341 Upstream fix: pymc-devs/pymc#7996
1 parent dea2ce0 commit 3edaae2

File tree

2 files changed

+165
-0
lines changed

2 files changed

+165
-0
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Hotfix module providing numerically stable log-probabilities for censored Normal distributions.
2+
3+
This is a temporary workaround until https://github.com/pymc-devs/pymc/pull/7996 is merged.
4+
5+
The fix monkey-patches the MeasurableClip logprob to use a stable log survival function
6+
for Normal distributions instead of the numerically unstable log(1 - exp(logcdf)).
7+
8+
Usage:
9+
import stable_censored_hotfix # Just import to apply the fix
10+
11+
with pm.Model():
12+
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
13+
y = pm.Censored("y", normal_dist, lower=None, upper=40.0, observed=data)
14+
"""
15+
16+
import numpy as np
17+
import pytensor.tensor as pt
18+
from pymc.distributions.dist_math import normal_lccdf
19+
from pymc.logprob.abstract import _logcdf, _logprob
20+
from pymc.logprob.censoring import MeasurableClip
21+
from pymc.logprob.utils import CheckParameterValue
22+
from pytensor.tensor.variable import TensorConstant
23+
24+
25+
def _stable_normal_logccdf(mu, sigma, value):
26+
"""Numerically stable log complementary CDF (log survival function) for Normal.
27+
28+
Uses erfcx-based implementation that is stable even in extreme tails.
29+
"""
30+
return normal_lccdf(mu, sigma, value)
31+
32+
33+
def _get_stable_logccdf(base_rv_op, base_rv_inputs, value, logcdf_fallback):
34+
"""Get numerically stable log complementary CDF if available.
35+
36+
For Normal distribution, uses the stable erfcx-based implementation.
37+
For other distributions, falls back to log1mexp(logcdf).
38+
"""
39+
from pytensor.tensor.random.basic import NormalRV
40+
41+
if isinstance(base_rv_op, NormalRV):
42+
# Normal distribution: use stable implementation
43+
# base_rv_inputs are: rng, size, mu, sigma
44+
rng, size, mu, sigma = base_rv_inputs
45+
return _stable_normal_logccdf(mu, sigma, value)
46+
else:
47+
# Fall back to potentially unstable computation
48+
return pt.log1mexp(logcdf_fallback)
49+
50+
51+
def _stable_clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
52+
r"""Stable logprob of a clipped censored distribution.
53+
54+
The probability is given by
55+
.. math::
56+
\begin{cases}
57+
0 & \text{for } x < lower, \\
58+
\text{CDF}(lower, dist) & \text{for } x = lower, \\
59+
\text{P}(x, dist) & \text{for } lower < x < upper, \\
60+
1-\text{CDF}(upper, dist) & \text {for} x = upper, \\
61+
0 & \text{for } x > upper,
62+
\end{cases}
63+
64+
"""
65+
(value,) = values
66+
67+
base_rv_op = base_rv.owner.op
68+
base_rv_inputs = base_rv.owner.inputs
69+
70+
logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
71+
logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
72+
73+
if base_rv_op.name:
74+
logprob.name = f"{base_rv_op}_logprob"
75+
logcdf.name = f"{base_rv_op}_logcdf"
76+
77+
is_lower_bounded, is_upper_bounded = False, False
78+
if not (
79+
isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
80+
):
81+
is_upper_bounded = True
82+
83+
# Use stable logccdf for Normal distributions instead of pt.log1mexp(logcdf)
84+
logccdf = _get_stable_logccdf(base_rv_op, base_rv_inputs, value, logcdf)
85+
86+
# For right clipped discrete RVs, we need to add an extra term
87+
# corresponding to the pmf at the upper bound
88+
if base_rv.dtype.startswith("int"):
89+
logccdf = pt.logaddexp(logccdf, logprob)
90+
91+
logprob = pt.switch(
92+
pt.eq(value, upper_bound),
93+
logccdf,
94+
pt.switch(pt.gt(value, upper_bound), -np.inf, logprob),
95+
)
96+
if not (
97+
isinstance(lower_bound, TensorConstant)
98+
and np.all(np.isneginf(lower_bound.value))
99+
):
100+
is_lower_bounded = True
101+
logprob = pt.switch(
102+
pt.eq(value, lower_bound),
103+
logcdf,
104+
pt.switch(pt.lt(value, lower_bound), -np.inf, logprob),
105+
)
106+
107+
if is_lower_bounded and is_upper_bounded:
108+
logprob = CheckParameterValue("lower_bound <= upper_bound")(
109+
logprob, pt.all(pt.le(lower_bound, upper_bound))
110+
)
111+
112+
return logprob
113+
114+
115+
def _apply_fix():
116+
"""Apply the fix by overriding the singledispatch registry."""
117+
# Use the register decorator to replace the existing function
118+
_logprob.register(MeasurableClip, _stable_clip_logprob)
119+
120+
121+
# Apply the fix on import
122+
_apply_fix()
123+
124+
125+
def verify_fix():
126+
"""Verify that the stable implementation works correctly."""
127+
import pymc as pm
128+
import scipy.stats as st
129+
130+
with pm.Model() as model:
131+
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
132+
pm.Censored("y", normal_dist, lower=None, upper=40.0)
133+
134+
logp_fn = model.compile_logp()
135+
result = logp_fn({"y": 40.0})
136+
expected = st.norm(0, 1).logsf(40.0)
137+
138+
if not np.isfinite(result):
139+
raise RuntimeError(
140+
f"Stable censored fix not working: got {result}, expected {expected}"
141+
)
142+
143+
if not np.isclose(result, expected, rtol=1e-6):
144+
raise RuntimeError(
145+
f"Stable censored result mismatch: got {result}, expected {expected}"
146+
)
147+
148+
return True
149+
150+
151+
if __name__ == "__main__":
152+
print("Verifying stable censored fix...")
153+
verify_fix()
154+
print("✓ Stable censored fix is working correctly!")
155+
print("\nUsage:")
156+
print(" import stable_censored_hotfix # Just import to apply the fix")
157+
print(" ")
158+
print(" with pm.Model():")
159+
print(" normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)")
160+
print(" y = pm.Censored('y', normal_dist, lower=None, upper=40.0)")

celeri/solve_mcmc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ class PymcModel:
1919
else:
2020
from pymc import Model as PymcModel
2121

22+
# Apply numerical stability fix for censored Normal distributions.
23+
# This is a workaround for https://github.com/pymc-devs/pymc/pull/7996
24+
# Fixes issue https://github.com/brendanjmeade/celeri/issues/341
25+
import celeri.censored_distribution_stability_hotfix # noqa: F401
26+
2227

2328
DIRECTION_IDX = {
2429
"strike_slip": slice(None, None, 2),

0 commit comments

Comments
 (0)