Skip to content

Commit 15e5f64

Browse files
committed
Move try/except fallback into _logccdf_helper
Centralizes the fallback logic so callers don't need to handle it. The helper now tries the stable _logccdf first and automatically falls back to log1mexp(logcdf) if not implemented.
1 parent 81da946 commit 15e5f64

File tree

4 files changed

+20
-22
lines changed

4 files changed

+20
-22
lines changed

pymc/logprob/abstract.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
from collections.abc import Sequence
4141
from functools import singledispatch
4242

43+
import pytensor.tensor as pt
44+
4345
from pytensor.graph import Apply, Op, Variable
4446
from pytensor.graph.utils import MetaType
4547
from pytensor.tensor import TensorVariable
@@ -129,8 +131,17 @@ def _logccdf(
129131

130132

131133
def _logccdf_helper(rv, value, **kwargs):
132-
"""Helper that calls `_logccdf` dispatcher."""
133-
logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
134+
"""Helper that calls `_logccdf` dispatcher with fallback to log1mexp(logcdf).
135+
136+
If a numerically stable `_logccdf` implementation is registered for the
137+
distribution, it will be used. Otherwise, falls back to computing
138+
`log(1 - exp(logcdf))` which may be numerically unstable in the tails.
139+
"""
140+
try:
141+
logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
142+
except NotImplementedError:
143+
logcdf = _logcdf_helper(rv, value, **kwargs)
144+
logccdf = pt.log1mexp(logcdf)
134145

135146
if rv.name:
136147
logccdf.name = f"{rv.name}_logccdf"

pymc/logprob/binary.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from pymc.logprob.abstract import (
2727
MeasurableElemwise,
28-
_logccdf,
28+
_logccdf_helper,
2929
_logcdf_helper,
3030
_logprob,
3131
_logprob_helper,
@@ -96,12 +96,7 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
9696
base_rv_op = base_rv.owner.op
9797

9898
logcdf = _logcdf_helper(base_rv, operand, **kwargs)
99-
# Try to use a numerically stable logccdf if available, otherwise fall back
100-
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
101-
try:
102-
logccdf = _logccdf(base_rv_op, operand, *base_rv.owner.inputs, **kwargs)
103-
except NotImplementedError:
104-
logccdf = pt.log1mexp(logcdf)
99+
logccdf = _logccdf_helper(base_rv, operand, **kwargs)
105100

106101
condn_exp = pt.eq(value, np.array(True))
107102

pymc/logprob/censoring.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
4848
from pytensor.tensor.variable import TensorConstant
4949

50-
from pymc.logprob.abstract import MeasurableElemwise, _logccdf, _logcdf, _logprob
50+
from pymc.logprob.abstract import MeasurableElemwise, _logccdf_helper, _logcdf, _logprob
5151
from pymc.logprob.rewriting import measurable_ir_rewrites_db
5252
from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables
5353

@@ -119,12 +119,8 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
119119
if not (isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))):
120120
is_upper_bounded = True
121121

122-
# Try to use a numerically stable logccdf if available, otherwise fall back
123-
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
124-
try:
125-
logccdf = _logccdf(base_rv_op, value, *base_rv_inputs, **kwargs)
126-
except NotImplementedError:
127-
logccdf = pt.log1mexp(logcdf)
122+
# Use numerically stable logccdf (falls back to log1mexp if not available)
123+
logccdf = _logccdf_helper(base_rv, value, **kwargs)
128124

129125
# For right clipped discrete RVs, we need to add an extra term
130126
# corresponding to the pmf at the upper bound

pymc/logprob/transforms.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
252252
# For discrete distributions, use the logcdf at the previous value
253253
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
254254
else:
255-
# Try to use a numerically stable logccdf if available, otherwise fall back
256-
# to computing log(1 - exp(logcdf)) which can be unstable in the tails
257-
try:
258-
logccdf = _logccdf_helper(measurable_input, backward_value)
259-
except NotImplementedError:
260-
logccdf = pt.log1mexp(logcdf)
255+
# Use numerically stable logccdf (falls back to log1mexp if not available)
256+
logccdf = _logccdf_helper(measurable_input, backward_value)
261257

262258
if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
263259
pass

0 commit comments

Comments
 (0)