Skip to content

Commit 15806c0

Browse files
committed
Add _logccdf support to Truncated distribution
Uses stable logccdf for computing log(1 - CDF(lower)) in truncated_logprob and truncated_logcdf instead of the potentially unstable log1mexp(logcdf).
1 parent 15e5f64 commit 15806c0

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

pymc/distributions/truncated.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from pymc.distributions.transforms import _default_transform
4545
from pymc.exceptions import TruncationError
4646
from pymc.logprob.abstract import _logcdf, _logprob
47-
from pymc.logprob.basic import icdf, logcdf, logp
47+
from pymc.logprob.basic import icdf, logccdf, logcdf, logp
4848
from pymc.math import logdiffexp
4949
from pymc.pytensorf import collect_default_updates
5050
from pymc.util import check_dist_not_registered
@@ -211,6 +211,23 @@ def _create_logcdf_exprs(
211211
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
212212
return lower_logcdf, upper_logcdf
213213

214+
@staticmethod
215+
def _create_lower_logccdf_expr(
216+
base_rv: TensorVariable,
217+
value: TensorVariable,
218+
lower: TensorVariable,
219+
) -> TensorVariable:
220+
"""Create logccdf expression at lower bound for base_rv.
221+
222+
Uses `value` as a template for broadcasting. This is numerically more
223+
stable than computing log(1 - exp(logcdf)) for distributions that have
224+
a registered logccdf method.
225+
"""
226+
# For left truncated discrete RVs, we need to include the whole lower bound.
227+
lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
228+
lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
229+
return logccdf(base_rv, lower_value, warn_rvs=False)
230+
214231
def update(self, node: Apply):
215232
"""Return the update mapping for the internal RNGs.
216233
@@ -401,7 +418,8 @@ def truncated_logprob(op, values, *inputs, **kwargs):
401418
if is_lower_bounded and is_upper_bounded:
402419
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
403420
elif is_lower_bounded:
404-
lognorm = pt.log1mexp(lower_logcdf)
421+
# Use numerically stable logccdf instead of log(1 - exp(logcdf))
422+
lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower)
405423
elif is_upper_bounded:
406424
lognorm = upper_logcdf
407425

@@ -438,7 +456,8 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
438456
if is_lower_bounded and is_upper_bounded:
439457
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
440458
elif is_lower_bounded:
441-
lognorm = pt.log1mexp(lower_logcdf)
459+
# Use numerically stable logccdf instead of log(1 - exp(logcdf))
460+
lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower)
442461
elif is_upper_bounded:
443462
lognorm = upper_logcdf
444463

0 commit comments

Comments
 (0)