Skip to content

Commit 87eba5d

Browse files
committed
Add numerical stability test for censored distributions
Test that pm.Censored computes log-probabilities stably at the bounds: - Right censoring (upper bound): log(1 - CDF) when CDF ≈ 1 - Left censoring (lower bound): log(CDF) when CDF ≈ 0 Uses pm.Censored with Normal(0, 1) at ±40 standard deviations.
1 parent cadb97a commit 87eba5d

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/logprob/test_censoring.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,49 @@ def test_rounding(rounding_op):
261261
logprob.eval({xr_vv: test_value}),
262262
expected_logp,
263263
)
264+
265+
266+
@pytest.mark.parametrize(
267+
"censoring_side,bound_value",
268+
[
269+
("right", 40.0), # Far right tail: CDF ≈ 1, need stable log(1-CDF)
270+
("left", -40.0), # Far left tail: CDF ≈ 0, need stable log(CDF)
271+
],
272+
)
273+
def test_censored_logprob_numerical_stability(censoring_side, bound_value):
274+
"""Test that censored distributions use numerically stable log-probability computations.
275+
276+
For right-censoring at the upper bound, log(1 - CDF) is computed. When CDF ≈ 1
277+
(far right tail), this requires a stable logccdf implementation.
278+
279+
For left-censoring at the lower bound, log(CDF) is computed. When CDF ≈ 0
280+
(far left tail), this requires a stable logcdf implementation.
281+
282+
This test uses pm.Censored which is the high-level API for censored distributions.
283+
"""
284+
import pymc as pm
285+
286+
ref_scipy = st.norm(0, 1)
287+
288+
with pm.Model() as model:
289+
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
290+
if censoring_side == "right":
291+
pm.Censored("y", normal_dist, lower=None, upper=bound_value)
292+
expected_logp = ref_scipy.logsf(bound_value) # log(1 - CDF)
293+
else: # left
294+
pm.Censored("y", normal_dist, lower=bound_value, upper=None)
295+
expected_logp = ref_scipy.logcdf(bound_value) # log(CDF)
296+
297+
# Compile the logp function
298+
logp_fn = model.compile_logp()
299+
300+
# Evaluate at the bound - this is where the log survival/cdf function is used
301+
logp_at_bound = logp_fn({"y": bound_value})
302+
303+
# This should be finite and correct, not -inf
304+
assert np.isfinite(logp_at_bound), (
305+
f"logp at {censoring_side} bound should be finite, got {logp_at_bound}"
306+
)
307+
assert np.isclose(logp_at_bound, expected_logp, rtol=1e-6), (
308+
f"logp at {censoring_side} bound: got {logp_at_bound}, expected {expected_logp}"
309+
)

0 commit comments

Comments
 (0)