|
44 | 44 | from pymc.distributions.transforms import _default_transform |
45 | 45 | from pymc.exceptions import TruncationError |
46 | 46 | from pymc.logprob.abstract import _logcdf, _logprob |
47 | | -from pymc.logprob.basic import icdf, logcdf, logp |
| 47 | +from pymc.logprob.basic import icdf, logccdf, logcdf, logp |
48 | 48 | from pymc.math import logdiffexp |
49 | 49 | from pymc.pytensorf import collect_default_updates |
50 | 50 | from pymc.util import check_dist_not_registered |
@@ -211,6 +211,23 @@ def _create_logcdf_exprs( |
211 | 211 | upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value}) |
212 | 212 | return lower_logcdf, upper_logcdf |
213 | 213 |
|
| 214 | + @staticmethod |
| 215 | + def _create_lower_logccdf_expr( |
| 216 | + base_rv: TensorVariable, |
| 217 | + value: TensorVariable, |
| 218 | + lower: TensorVariable, |
| 219 | + ) -> TensorVariable: |
| 220 | + """Create logccdf expression at lower bound for base_rv. |
| 221 | +
|
| 222 | + Uses `value` as a template for broadcasting. This is numerically more |
| 223 | + stable than computing log(1 - exp(logcdf)) for distributions that have |
| 224 | + a registered logccdf method. |
| 225 | + """ |
| 226 | + # For left truncated discrete RVs, we need to include the whole lower bound. |
| 227 | + lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower |
| 228 | + lower_value = pt.full_like(value, lower_value, dtype=config.floatX) |
| 229 | + return logccdf(base_rv, lower_value, warn_rvs=False) |
| 230 | + |
214 | 231 | def update(self, node: Apply): |
215 | 232 | """Return the update mapping for the internal RNGs. |
216 | 233 |
|
@@ -401,7 +418,8 @@ def truncated_logprob(op, values, *inputs, **kwargs): |
401 | 418 | if is_lower_bounded and is_upper_bounded: |
402 | 419 | lognorm = logdiffexp(upper_logcdf, lower_logcdf) |
403 | 420 | elif is_lower_bounded: |
404 | | - lognorm = pt.log1mexp(lower_logcdf) |
| 421 | + # Use numerically stable logccdf instead of log(1 - exp(logcdf)) |
| 422 | + lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower) |
405 | 423 | elif is_upper_bounded: |
406 | 424 | lognorm = upper_logcdf |
407 | 425 |
|
@@ -438,7 +456,8 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs): |
438 | 456 | if is_lower_bounded and is_upper_bounded: |
439 | 457 | lognorm = logdiffexp(upper_logcdf, lower_logcdf) |
440 | 458 | elif is_lower_bounded: |
441 | | - lognorm = pt.log1mexp(lower_logcdf) |
| 459 | + # Use numerically stable logccdf instead of log(1 - exp(logcdf)) |
| 460 | + lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower) |
442 | 461 | elif is_upper_bounded: |
443 | 462 | lognorm = upper_logcdf |
444 | 463 |
|
|
0 commit comments