Skip to content

Commit 81da946

Browse files
committed
Add _logccdf dispatcher for numerically stable log survival function
Add _logccdf (log complementary CDF / log survival function) support: - pymc/logprob/abstract.py: Add _logccdf singledispatch and _logccdf_helper - pymc/distributions/distribution.py: Register logccdf methods via metaclass - pymc/distributions/continuous.py: Add logccdf to Normal using stable normal_lccdf - pymc/logprob/censoring.py: Use _logccdf for right-censored distributions - pymc/logprob/binary.py: Use _logccdf for comparison operations - pymc/logprob/transforms.py: Use _logccdf_helper for monotonic transforms - pymc/logprob/basic.py: Add public logccdf() function - pymc/logprob/__init__.py: Export logccdf This fixes numerical instability when computing log-probabilities for censored Normal distributions at extreme tail values (e.g., 10+ sigma).
1 parent 87eba5d commit 81da946

File tree

9 files changed

+205
-7
lines changed

9 files changed

+205
-7
lines changed

pymc/distributions/continuous.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,13 @@ def logcdf(value, mu, sigma):
512512
msg="sigma > 0",
513513
)
514514

515+
def logccdf(value, mu, sigma):
516+
return check_parameters(
517+
normal_lccdf(mu, sigma, value),
518+
sigma > 0,
519+
msg="sigma > 0",
520+
)
521+
515522
def icdf(value, mu, sigma):
516523
res = mu + sigma * -np.sqrt(2.0) * pt.erfcinv(2 * value)
517524
res = check_icdf_value(res, value)

pymc/distributions/distribution.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
rv_size_is_none,
5151
shape_from_dims,
5252
)
53-
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
53+
from pymc.logprob.abstract import MeasurableOp, _icdf, _logccdf, _logcdf, _logprob
5454
from pymc.logprob.basic import logp
5555
from pymc.logprob.rewriting import logprob_rewrites_db
5656
from pymc.printing import str_for_dist
@@ -150,6 +150,17 @@ def logcdf(op, value, *dist_params, **kwargs):
150150
dist_params = [dist_params[i] for i in params_idxs]
151151
return class_logcdf(value, *dist_params)
152152

153+
class_logccdf = clsdict.get("logccdf")
154+
if class_logccdf:
155+
156+
@_logccdf.register(rv_type)
157+
def logccdf(op, value, *dist_params, **kwargs):
158+
if isinstance(op, RandomVariable):
159+
rng, size, *dist_params = dist_params
160+
elif params_idxs:
161+
dist_params = [dist_params[i] for i in params_idxs]
162+
return class_logccdf(value, *dist_params)
163+
153164
class_icdf = clsdict.get("icdf")
154165
if class_icdf:
155166

pymc/logprob/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from pymc.logprob.basic import (
4040
conditional_logp,
4141
icdf,
42+
logccdf,
4243
logcdf,
4344
logp,
4445
transformed_conditional_logp,
@@ -59,6 +60,7 @@
5960

6061
__all__ = (
6162
"icdf",
63+
"logccdf",
6264
"logcdf",
6365
"logp",
6466
)

pymc/logprob/abstract.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,36 @@ def _logcdf_helper(rv, value, **kwargs):
108108
return logcdf
109109

110110

111+
@singledispatch
112+
def _logccdf(
113+
op: Op,
114+
value: TensorVariable,
115+
*inputs: TensorVariable,
116+
**kwargs,
117+
):
118+
"""Create a graph for the log complementary CDF (log survival function) of a ``RandomVariable``.
119+
120+
This function dispatches on the type of ``op``, which should be a subclass
121+
of ``RandomVariable``. If you want to implement new logccdf graphs
122+
for a ``RandomVariable``, register a new function on this dispatcher.
123+
124+
The log complementary CDF is defined as log(1 - CDF(x)), also known as the
125+
log survival function. For distributions with a numerically stable implementation,
126+
this should be used instead of computing log(1 - exp(logcdf)).
127+
"""
128+
raise NotImplementedError(f"LogCCDF method not implemented for {op}")
129+
130+
131+
def _logccdf_helper(rv, value, **kwargs):
132+
"""Helper that calls `_logccdf` dispatcher."""
133+
logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
134+
135+
if rv.name:
136+
logccdf.name = f"{rv.name}_logccdf"
137+
138+
return logccdf
139+
140+
111141
@singledispatch
112142
def _icdf(
113143
op: Op,

pymc/logprob/basic.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pymc.logprob.abstract import (
5454
MeasurableOp,
5555
_icdf_helper,
56+
_logccdf_helper,
5657
_logcdf_helper,
5758
_logprob,
5859
_logprob_helper,
@@ -302,6 +303,69 @@ def normal_logcdf(value, mu, sigma):
302303
return expr
303304

304305

306+
def logccdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
307+
"""Create a graph for the log complementary CDF (log survival function) of a random variable.
308+
309+
The log complementary CDF is defined as log(1 - CDF(x)), also known as the
310+
log survival function. For distributions with a numerically stable implementation,
311+
this is more accurate than computing log(1 - exp(logcdf)).
312+
313+
Parameters
314+
----------
315+
rv : TensorVariable
316+
value : tensor_like
317+
Should be the same type (shape and dtype) as the rv.
318+
warn_rvs : bool, default True
319+
Warn if RVs were found in the logccdf graph.
320+
This can happen when a variable has other random variables as inputs.
321+
In that case, those random variables should be replaced by their respective values.
322+
323+
Returns
324+
-------
325+
logccdf : TensorVariable
326+
327+
Raises
328+
------
329+
RuntimeError
330+
If the logccdf cannot be derived.
331+
332+
Examples
333+
--------
334+
Create a compiled function that evaluates the logccdf of a variable
335+
336+
.. code-block:: python
337+
338+
import pymc as pm
339+
import pytensor.tensor as pt
340+
341+
mu = pt.scalar("mu")
342+
rv = pm.Normal.dist(mu, 1.0)
343+
344+
value = pt.scalar("value")
345+
rv_logccdf = pm.logccdf(rv, value)
346+
347+
# Use .eval() for debugging
348+
print(rv_logccdf.eval({value: 0.9, mu: 0.0})) # -1.5272506
349+
350+
# Compile a function for repeated evaluations
351+
rv_logccdf_fn = pm.compile_pymc([value, mu], rv_logccdf)
352+
print(rv_logccdf_fn(value=0.9, mu=0.0)) # -1.5272506
353+
354+
"""
355+
value = pt.as_tensor_variable(value, dtype=rv.dtype)
356+
try:
357+
return _logccdf_helper(rv, value, **kwargs)
358+
except NotImplementedError:
359+
# Try to rewrite rv
360+
fgraph, _, _ = construct_ir_fgraph({rv: value})
361+
[ir_rv] = fgraph.outputs
362+
expr = _logccdf_helper(ir_rv, value, **kwargs)
363+
[expr] = cleanup_ir([expr])
364+
if warn_rvs:
365+
_warn_rvs_in_inferred_graph([expr])
366+
return expr
367+
368+
305369
def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
306370
"""Create a graph for the inverse CDF of a random variable.
307371

pymc/logprob/binary.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from pymc.logprob.abstract import (
2727
MeasurableElemwise,
28+
_logccdf,
2829
_logcdf_helper,
2930
_logprob,
3031
_logprob_helper,
@@ -95,7 +96,12 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
9596
base_rv_op = base_rv.owner.op
9697

9798
logcdf = _logcdf_helper(base_rv, operand, **kwargs)
98-
logccdf = pt.log1mexp(logcdf)
99+
# Try to use a numerically stable logccdf if available, otherwise fall back
100+
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
101+
try:
102+
logccdf = _logccdf(base_rv_op, operand, *base_rv.owner.inputs, **kwargs)
103+
except NotImplementedError:
104+
logccdf = pt.log1mexp(logcdf)
99105

100106
condn_exp = pt.eq(value, np.array(True))
101107

pymc/logprob/censoring.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
4848
from pytensor.tensor.variable import TensorConstant
4949

50-
from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
50+
from pymc.logprob.abstract import MeasurableElemwise, _logccdf, _logcdf, _logprob
5151
from pymc.logprob.rewriting import measurable_ir_rewrites_db
5252
from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables
5353

@@ -119,7 +119,13 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
119119
if not (isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))):
120120
is_upper_bounded = True
121121

122-
logccdf = pt.log1mexp(logcdf)
122+
# Try to use a numerically stable logccdf if available, otherwise fall back
123+
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
124+
try:
125+
logccdf = _logccdf(base_rv_op, value, *base_rv_inputs, **kwargs)
126+
except NotImplementedError:
127+
logccdf = pt.log1mexp(logcdf)
128+
123129
# For right clipped discrete RVs, we need to add an extra term
124130
# corresponding to the pmf at the upper bound
125131
if base_rv.dtype.startswith("int"):

pymc/logprob/transforms.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
MeasurableOp,
112112
_icdf,
113113
_icdf_helper,
114+
_logccdf_helper,
114115
_logcdf,
115116
_logcdf_helper,
116117
_logprob,
@@ -248,9 +249,15 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
248249

249250
logcdf = _logcdf_helper(measurable_input, backward_value)
250251
if is_discrete:
252+
# For discrete distributions, use the logcdf at the previous value
251253
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
252254
else:
253-
logccdf = pt.log1mexp(logcdf)
255+
# Try to use a numerically stable logccdf if available, otherwise fall back
256+
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
257+
try:
258+
logccdf = _logccdf_helper(measurable_input, backward_value)
259+
except NotImplementedError:
260+
logccdf = pt.log1mexp(logcdf)
254261

255262
if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
256263
pass

tests/logprob/test_abstract.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@
4545

4646
import pymc as pm
4747

48-
from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logcdf_helper
49-
from pymc.logprob.basic import logcdf
48+
from pymc.logprob.abstract import (
49+
MeasurableElemwise,
50+
MeasurableOp,
51+
_logccdf_helper,
52+
_logcdf_helper,
53+
)
54+
from pymc.logprob.basic import logccdf, logcdf
5055

5156

5257
def assert_equal_hash(classA, classB):
@@ -80,6 +85,38 @@ def test_logcdf_helper():
8085
np.testing.assert_almost_equal(x_logcdf.eval(), sp.norm(0, 1).logcdf([0, 1]))
8186

8287

88+
def test_logccdf_helper():
89+
value = pt.vector("value")
90+
x = pm.Normal.dist(0, 1)
91+
92+
x_logccdf = _logccdf_helper(x, value)
93+
np.testing.assert_almost_equal(x_logccdf.eval({value: [0, 1]}), sp.norm(0, 1).logsf([0, 1]))
94+
95+
x_logccdf = _logccdf_helper(x, [0, 1])
96+
np.testing.assert_almost_equal(x_logccdf.eval(), sp.norm(0, 1).logsf([0, 1]))
97+
98+
99+
def test_logccdf_helper_numerical_stability():
100+
"""Test that logccdf is numerically stable in the far right tail.
101+
102+
This is where log(1 - exp(logcdf)) would lose precision because CDF is very close to 1.
103+
"""
104+
x = pm.Normal.dist(0, 1)
105+
106+
# Test value far in the right tail where CDF is essentially 1
107+
far_tail_value = 10.0
108+
109+
x_logccdf = _logccdf_helper(x, far_tail_value)
110+
result = x_logccdf.eval()
111+
112+
# scipy.stats.norm.logsf uses a numerically stable implementation
113+
expected = sp.norm(0, 1).logsf(far_tail_value)
114+
115+
# The naive computation would give log(1 - 1) = -inf or very wrong values
116+
# The stable implementation should match scipy's logsf closely
117+
np.testing.assert_almost_equal(result, expected, decimal=6)
118+
119+
83120
def test_logcdf_transformed_argument():
84121
with pm.Model() as m:
85122
sigma = pm.HalfFlat("sigma")
@@ -95,3 +132,31 @@ def test_logcdf_transformed_argument():
95132
pm.TruncatedNormal.dist(0, sigma_value, lower=None, upper=1.0), x_value
96133
).eval()
97134
assert np.isclose(observed, expected)
135+
136+
137+
def test_logccdf():
138+
"""Test the public logccdf function."""
139+
value = pt.vector("value")
140+
x = pm.Normal.dist(0, 1)
141+
142+
x_logccdf = logccdf(x, value)
143+
np.testing.assert_almost_equal(x_logccdf.eval({value: [0, 1]}), sp.norm(0, 1).logsf([0, 1]))
144+
145+
146+
def test_logccdf_numerical_stability():
147+
"""Test that pm.logccdf is numerically stable in the extreme right tail.
148+
149+
For a normal distribution, the log survival function at x=10 is very negative
150+
(around -52). Using log(1 - exp(logcdf)) would fail because CDF(10) is essentially 1.
151+
"""
152+
x = pm.Normal.dist(0, 1)
153+
154+
# Test value far in the right tail
155+
far_tail_value = 10.0
156+
157+
result = logccdf(x, far_tail_value).eval()
158+
expected = sp.norm(0, 1).logsf(far_tail_value)
159+
160+
# Should be around -52, not -inf or nan
161+
assert np.isfinite(result)
162+
np.testing.assert_almost_equal(result, expected, decimal=6)

0 commit comments

Comments
 (0)