-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Add _logccdf dispatcher for numerically stable log survival function in censored distributions
#7996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add _logccdf dispatcher for numerically stable log survival function in censored distributions
#7996
Changes from 2 commits
87eba5d
81da946
15e5f64
15806c0
063af42
19b9979
20322a1
93734c2
2df5274
63c9327
628e6d5
36b8672
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,6 +108,36 @@ def _logcdf_helper(rv, value, **kwargs): | |
| return logcdf | ||
|
|
||
|
|
||
@singledispatch
def _logccdf(
    op: Op,
    value: TensorVariable,
    *inputs: TensorVariable,
    **kwargs,
):
    """Build the log complementary CDF (log survival function) graph for ``op``.

    This is the generic entry of a ``singledispatch`` dispatcher keyed on the
    type of ``op`` (expected to be a ``RandomVariable`` subclass). To support a
    new ``RandomVariable``, register an implementation on this dispatcher.

    The quantity is ``log(1 - CDF(x))``. A distribution that has a numerically
    stable expression for it should register that expression here rather than
    relying on ``log(1 - exp(logcdf))``.

    Raises
    ------
    NotImplementedError
        Always, for this generic entry: no implementation is registered for
        the type of ``op``.
    """
    message = f"LogCCDF method not implemented for {op}"
    raise NotImplementedError(message)
|
|
||
|
|
||
def _logccdf_helper(rv, value, **kwargs):
    """Build the log complementary CDF (log survival function) graph of ``rv`` at ``value``.

    The ``_logccdf`` dispatcher is tried first so that a registered, numerically
    stable implementation is used when one exists. If the dispatcher raises
    ``NotImplementedError``, the result falls back to ``log1mexp(logcdf)``,
    i.e. ``log(1 - exp(logcdf))``, built from ``_logcdf_helper``.

    Parameters
    ----------
    rv
        Variable whose ``owner.op`` and ``owner.inputs`` are handed to the dispatcher.
    value
        Value at which the log complementary CDF is evaluated.
    **kwargs
        Forwarded unchanged to ``_logccdf`` (and to ``_logcdf_helper`` on fallback).

    Returns
    -------
    The logccdf graph. When ``rv.name`` is set, the result is named
    ``"{rv.name}_logccdf"``.
    """
    try:
        logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
    except NotImplementedError:
        # Review request: do the fallback here so that users/devs calling this
        # helper do not have to repeat the try/except themselves.
        # Imported locally so this block does not depend on a module-level alias.
        from pytensor.tensor import log1mexp

        logccdf = log1mexp(_logcdf_helper(rv, value, **kwargs))

    if rv.name:
        logccdf.name = f"{rv.name}_logccdf"

    return logccdf
|
|
||
|
|
||
| @singledispatch | ||
| def _icdf( | ||
| op: Op, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -261,3 +261,49 @@ def test_rounding(rounding_op): | |
| logprob.eval({xr_vv: test_value}), | ||
| expected_logp, | ||
| ) | ||
|
|
||
|
|
||
@pytest.mark.parametrize(
    "censoring_side,bound_value",
    [
        ("right", 40.0),  # Far right tail: CDF ≈ 1, need stable log(1-CDF)
        ("left", -40.0),  # Far left tail: CDF ≈ 0, need stable log(CDF)
    ],
)
def test_censored_logprob_numerical_stability(censoring_side, bound_value):
    """Check that ``pm.Censored`` yields a finite, accurate logp at an extreme bound.

    Right-censoring evaluated at the upper bound needs log(1 - CDF); with
    CDF ≈ 1 in the far right tail this calls for a stable logccdf.

    Left-censoring evaluated at the lower bound needs log(CDF); with CDF ≈ 0
    in the far left tail this calls for a stable logcdf.

    The model is built through the high-level ``pm.Censored`` API and compared
    against ``scipy.stats.norm`` (``logsf`` / ``logcdf``) as the reference.
    """
    import pymc as pm

    standard_normal = st.norm(0, 1)

    # Pick the censoring bounds and the matching scipy reference value up front.
    if censoring_side == "right":
        lower, upper = None, bound_value
        expected_logp = standard_normal.logsf(bound_value)  # log(1 - CDF)
    else:  # left
        lower, upper = bound_value, None
        expected_logp = standard_normal.logcdf(bound_value)  # log(CDF)

    with pm.Model() as model:
        base_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
        pm.Censored("y", base_dist, lower=lower, upper=upper)

    # Evaluating exactly at the bound exercises the log survival / log CDF term.
    compiled_logp = model.compile_logp()
    logp_at_bound = compiled_logp({"y": bound_value})

    # The result must be finite (not -inf) and match the reference.
    assert np.isfinite(logp_at_bound), (
        f"logp at {censoring_side} bound should be finite, got {logp_at_bound}"
    )
    assert np.isclose(logp_at_bound, expected_logp, rtol=1e-6), (
        f"logp at {censoring_side} bound: got {logp_at_bound}, expected {expected_logp}"
    )
Uh oh!
There was an error while loading. Please reload this page.