Changes from 2 commits
7 changes: 7 additions & 0 deletions pymc/distributions/continuous.py
@@ -512,6 +512,13 @@ def logcdf(value, mu, sigma):
msg="sigma > 0",
)

def logccdf(value, mu, sigma):
return check_parameters(
normal_lccdf(mu, sigma, value),
sigma > 0,
msg="sigma > 0",
)

def icdf(value, mu, sigma):
res = mu + sigma * -np.sqrt(2.0) * pt.erfcinv(2 * value)
res = check_icdf_value(res, value)
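For context on why the Normal distribution gets a dedicated logccdf built on normal_lccdf: deriving the complementary CDF from the CDF loses all precision once the CDF rounds to 1. A standalone SciPy sketch of the effect (illustrative only, not code from this PR):

import numpy as np
from scipy import stats

x = 10.0  # far right tail of a standard normal

# Naive route: log(1 - exp(logcdf)). CDF(10) rounds to 1.0 in float64,
# so this collapses to log(0) = -inf.
naive = np.log1p(-np.exp(stats.norm.logcdf(x)))

# A dedicated log survival function keeps full precision (about -53.2).
stable = stats.norm.logsf(x)

print(naive, stable)  # -inf  vs.  -53.23...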
13 changes: 12 additions & 1 deletion pymc/distributions/distribution.py
@@ -50,7 +50,7 @@
rv_size_is_none,
shape_from_dims,
)
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableOp, _icdf, _logccdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.printing import str_for_dist
@@ -150,6 +150,17 @@ def logcdf(op, value, *dist_params, **kwargs):
dist_params = [dist_params[i] for i in params_idxs]
return class_logcdf(value, *dist_params)

class_logccdf = clsdict.get("logccdf")
if class_logccdf:

@_logccdf.register(rv_type)
def logccdf(op, value, *dist_params, **kwargs):
if isinstance(op, RandomVariable):
rng, size, *dist_params = dist_params
elif params_idxs:
dist_params = [dist_params[i] for i in params_idxs]
return class_logccdf(value, *dist_params)

class_icdf = clsdict.get("icdf")
if class_icdf:

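The new block mirrors the existing logcdf/icdf handling: when a distribution class defines logccdf, the metaclass registers it on the _logccdf dispatcher, dropping the rng and size inputs that every RandomVariable carries before its parameters. A simplified standalone sketch of that dispatch pattern (hypothetical names, RandomVariable branch only, not PyMC internals):

from functools import singledispatch


class MyNormalRV:  # stand-in for a RandomVariable Op type
    pass


@singledispatch
def _logccdf(op, value, *dist_params, **kwargs):
    raise NotImplementedError(f"LogCCDF method not implemented for {op}")


def register_logccdf(rv_type, class_logccdf):
    # Mimics DistributionMeta finding ``logccdf`` in the class dict.
    @_logccdf.register(rv_type)
    def logccdf(op, value, *dist_params, **kwargs):
        # RandomVariable inputs are (rng, size, *dist_params); drop the first two.
        rng, size, *dist_params = dist_params
        return class_logccdf(value, *dist_params)


register_logccdf(MyNormalRV, lambda value, mu, sigma: (value, mu, sigma))
print(_logccdf(MyNormalRV(), 0.5, "rng", "size", 0.0, 1.0))  # (0.5, 0.0, 1.0)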
2 changes: 2 additions & 0 deletions pymc/logprob/__init__.py
@@ -39,6 +39,7 @@
from pymc.logprob.basic import (
conditional_logp,
icdf,
logccdf,
logcdf,
logp,
transformed_conditional_logp,
@@ -59,6 +60,7 @@

__all__ = (
"icdf",
"logccdf",
"logcdf",
"logp",
)
30 changes: 30 additions & 0 deletions pymc/logprob/abstract.py
@@ -108,6 +108,36 @@ def _logcdf_helper(rv, value, **kwargs):
return logcdf


@singledispatch
def _logccdf(
op: Op,
value: TensorVariable,
*inputs: TensorVariable,
**kwargs,
):
"""Create a graph for the log complementary CDF (log survival function) of a ``RandomVariable``.

This function dispatches on the type of ``op``, which should be a subclass
of ``RandomVariable``. If you want to implement new logccdf graphs
for a ``RandomVariable``, register a new function on this dispatcher.

The log complementary CDF is defined as log(1 - CDF(x)), also known as the
log survival function. For distributions with a numerically stable implementation,
this should be used instead of computing log(1 - exp(logcdf)).
"""
raise NotImplementedError(f"LogCCDF method not implemented for {op}")


def _logccdf_helper(rv, value, **kwargs):
Member:
Make this method do the try except fallback to log1mexp? So users/devs don't need to do it all the time

Contributor Author:
Done! The helper now tries _logccdf first and automatically falls back to log1mexp(logcdf) if not implemented. Callers no longer need to handle the exception.

15e5f64

"""Helper that calls `_logccdf` dispatcher."""
logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)

if rv.name:
logccdf.name = f"{rv.name}_logccdf"

return logccdf


@singledispatch
def _icdf(
op: Op,
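The docstring above tells implementers to register new logccdf graphs on this dispatcher. A hedged sketch of such a registration for a hypothetical exponential RV op (the op class and parameter layout are assumptions for illustration, not part of this PR):

import pytensor.tensor as pt

from pymc.logprob.abstract import _logccdf


class MyExponentialRV:  # hypothetical RandomVariable subclass, illustration only
    pass


@_logccdf.register(MyExponentialRV)
def my_exponential_logccdf(op, value, rng, size, lam, **kwargs):
    # The survival function of an exponential is exp(-lam * x), so the log
    # complementary CDF is -lam * x for x >= 0 and 0 otherwise; no log1mexp needed.
    return pt.switch(pt.ge(value, 0), -lam * value, 0.0)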
64 changes: 64 additions & 0 deletions pymc/logprob/basic.py
@@ -53,6 +53,7 @@
from pymc.logprob.abstract import (
MeasurableOp,
_icdf_helper,
_logccdf_helper,
_logcdf_helper,
_logprob,
_logprob_helper,
@@ -302,6 +303,69 @@ def normal_logcdf(value, mu, sigma):
return expr


def logccdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
"""Create a graph for the log complementary CDF (log survival function) of a random variable.

The log complementary CDF is defined as log(1 - CDF(x)), also known as the
log survival function. For distributions with a numerically stable implementation,
this is more accurate than computing log(1 - exp(logcdf)).

Parameters
----------
rv : TensorVariable
value : tensor_like
Should be the same type (shape and dtype) as the rv.
warn_rvs : bool, default True
Warn if RVs were found in the logccdf graph.
This can happen when a variable has other random variables as inputs.
In that case, those random variables should be replaced by their respective values.

Returns
-------
logccdf : TensorVariable

Raises
------
RuntimeError
If the logccdf cannot be derived.

Examples
--------
Create a compiled function that evaluates the logccdf of a variable

.. code-block:: python

import pymc as pm
import pytensor.tensor as pt

mu = pt.scalar("mu")
rv = pm.Normal.dist(mu, 1.0)

value = pt.scalar("value")
rv_logccdf = pm.logccdf(rv, value)

# Use .eval() for debugging
print(rv_logccdf.eval({value: 0.9, mu: 0.0})) # -1.6925

# Compile a function for repeated evaluations
rv_logccdf_fn = pm.compile_pymc([value, mu], rv_logccdf)
print(rv_logccdf_fn(value=0.9, mu=0.0)) # -1.6925

"""
value = pt.as_tensor_variable(value, dtype=rv.dtype)
try:
return _logccdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv
fgraph, _, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _logccdf_helper(ir_rv, value, **kwargs)
[expr] = cleanup_ir([expr])
if warn_rvs:
_warn_rvs_in_inferred_graph([expr])
return expr


def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
"""Create a graph for the inverse CDF of a random variable.

8 changes: 7 additions & 1 deletion pymc/logprob/binary.py
@@ -25,6 +25,7 @@

from pymc.logprob.abstract import (
MeasurableElemwise,
_logccdf,
_logcdf_helper,
_logprob,
_logprob_helper,
@@ -95,7 +96,12 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
base_rv_op = base_rv.owner.op

logcdf = _logcdf_helper(base_rv, operand, **kwargs)
logccdf = pt.log1mexp(logcdf)
# Try to use a numerically stable logccdf if available, otherwise fall back
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
try:
logccdf = _logccdf(base_rv_op, operand, *base_rv.owner.inputs, **kwargs)
except NotImplementedError:
logccdf = pt.log1mexp(logcdf)

condn_exp = pt.eq(value, np.array(True))

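This branch is hit when taking the log-probability of a boolean comparison of a random variable. A hedged usage sketch of the tail behaviour it is meant to improve (assuming the comparison rewrite applies to rv > constant, as this module implements; numbers are for a standard normal):

import pymc as pm
import scipy.stats as st

x = pm.Normal.dist(0, 1)

# log P(X > 10): with a stable logccdf this should be close to scipy's
# logsf(10) (about -53.2) instead of underflowing to -inf via log1mexp(logcdf).
print(pm.logp(x > 10.0, True).eval())
print(st.norm.logsf(10.0))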
10 changes: 8 additions & 2 deletions pymc/logprob/censoring.py
@@ -47,7 +47,7 @@
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableElemwise, _logccdf, _logcdf, _logprob
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables

@@ -119,7 +119,13 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
if not (isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))):
is_upper_bounded = True

logccdf = pt.log1mexp(logcdf)
# Try to use a numerically stable logccdf if available, otherwise fall back
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
try:
logccdf = _logccdf(base_rv_op, value, *base_rv_inputs, **kwargs)
except NotImplementedError:
logccdf = pt.log1mexp(logcdf)

# For right clipped discrete RVs, we need to add an extra term
# corresponding to the pmf at the upper bound
if base_rv.dtype.startswith("int"):
9 changes: 8 additions & 1 deletion pymc/logprob/transforms.py
@@ -111,6 +111,7 @@
MeasurableOp,
_icdf,
_icdf_helper,
_logccdf_helper,
_logcdf,
_logcdf_helper,
_logprob,
@@ -248,9 +249,15 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg

logcdf = _logcdf_helper(measurable_input, backward_value)
if is_discrete:
# For discrete distributions, use the logcdf at the previous value
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
else:
logccdf = pt.log1mexp(logcdf)
# Try to use a numerically stable logccdf if available, otherwise fall back
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
try:
logccdf = _logccdf_helper(measurable_input, backward_value)
except NotImplementedError:
logccdf = pt.log1mexp(logcdf)

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
69 changes: 67 additions & 2 deletions tests/logprob/test_abstract.py
@@ -45,8 +45,13 @@

import pymc as pm

from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logcdf_helper
from pymc.logprob.basic import logcdf
from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableOp,
_logccdf_helper,
_logcdf_helper,
)
from pymc.logprob.basic import logccdf, logcdf


def assert_equal_hash(classA, classB):
@@ -80,6 +85,38 @@ def test_logcdf_helper():
np.testing.assert_almost_equal(x_logcdf.eval(), sp.norm(0, 1).logcdf([0, 1]))


def test_logccdf_helper():
value = pt.vector("value")
x = pm.Normal.dist(0, 1)

x_logccdf = _logccdf_helper(x, value)
np.testing.assert_almost_equal(x_logccdf.eval({value: [0, 1]}), sp.norm(0, 1).logsf([0, 1]))

x_logccdf = _logccdf_helper(x, [0, 1])
np.testing.assert_almost_equal(x_logccdf.eval(), sp.norm(0, 1).logsf([0, 1]))


def test_logccdf_helper_numerical_stability():
"""Test that logccdf is numerically stable in the far right tail.

This is where log(1 - exp(logcdf)) would lose precision because CDF is very close to 1.
"""
x = pm.Normal.dist(0, 1)

# Test value far in the right tail where CDF is essentially 1
far_tail_value = 10.0

x_logccdf = _logccdf_helper(x, far_tail_value)
result = x_logccdf.eval()

# scipy.stats.norm.logsf uses a numerically stable implementation
expected = sp.norm(0, 1).logsf(far_tail_value)

# The naive computation would give log(1 - 1) = -inf or very wrong values
# The stable implementation should match scipy's logsf closely
np.testing.assert_almost_equal(result, expected, decimal=6)


def test_logcdf_transformed_argument():
with pm.Model() as m:
sigma = pm.HalfFlat("sigma")
@@ -95,3 +132,31 @@ def test_logcdf_transformed_argument():
pm.TruncatedNormal.dist(0, sigma_value, lower=None, upper=1.0), x_value
).eval()
assert np.isclose(observed, expected)


def test_logccdf():
"""Test the public logccdf function."""
value = pt.vector("value")
x = pm.Normal.dist(0, 1)

x_logccdf = logccdf(x, value)
np.testing.assert_almost_equal(x_logccdf.eval({value: [0, 1]}), sp.norm(0, 1).logsf([0, 1]))


def test_logccdf_numerical_stability():
"""Test that pm.logccdf is numerically stable in the extreme right tail.

For a normal distribution, the log survival function at x=10 is very negative
(around -53). Using log(1 - exp(logcdf)) would fail because CDF(10) is essentially 1.
"""
x = pm.Normal.dist(0, 1)

# Test value far in the right tail
far_tail_value = 10.0

result = logccdf(x, far_tail_value).eval()
expected = sp.norm(0, 1).logsf(far_tail_value)

# Should be around -53, not -inf or nan
assert np.isfinite(result)
np.testing.assert_almost_equal(result, expected, decimal=6)
46 changes: 46 additions & 0 deletions tests/logprob/test_censoring.py
@@ -261,3 +261,49 @@ def test_rounding(rounding_op):
logprob.eval({xr_vv: test_value}),
expected_logp,
)


@pytest.mark.parametrize(
"censoring_side,bound_value",
[
("right", 40.0), # Far right tail: CDF ≈ 1, need stable log(1-CDF)
("left", -40.0), # Far left tail: CDF ≈ 0, need stable log(CDF)
],
)
def test_censored_logprob_numerical_stability(censoring_side, bound_value):
"""Test that censored distributions use numerically stable log-probability computations.

For right-censoring at the upper bound, log(1 - CDF) is computed. When CDF ≈ 1
(far right tail), this requires a stable logccdf implementation.

For left-censoring at the lower bound, log(CDF) is computed. When CDF ≈ 0
(far left tail), this requires a stable logcdf implementation.

This test uses pm.Censored which is the high-level API for censored distributions.
"""
import pymc as pm
Copilot AI (Dec 15, 2025):
Module 'pymc' is imported with both 'import' and 'import from'.

Contributor Author:
Fixed. Moved import pymc as pm to top-level imports.

19b9979


ref_scipy = st.norm(0, 1)

with pm.Model() as model:
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
if censoring_side == "right":
pm.Censored("y", normal_dist, lower=None, upper=bound_value)
expected_logp = ref_scipy.logsf(bound_value) # log(1 - CDF)
else: # left
pm.Censored("y", normal_dist, lower=bound_value, upper=None)
expected_logp = ref_scipy.logcdf(bound_value) # log(CDF)

# Compile the logp function
logp_fn = model.compile_logp()

# Evaluate at the bound - this is where the log survival/cdf function is used
logp_at_bound = logp_fn({"y": bound_value})

# This should be finite and correct, not -inf
assert np.isfinite(logp_at_bound), (
f"logp at {censoring_side} bound should be finite, got {logp_at_bound}"
)
assert np.isclose(logp_at_bound, expected_logp, rtol=1e-6), (
f"logp at {censoring_side} bound: got {logp_at_bound}, expected {expected_logp}"
)