Skip to content

Commit 93734c2

Browse files
committed
Add test for _logccdf_helper fallback to log1mexp
Verifies that distributions without a registered _logccdf method (e.g., Uniform) use the log1mexp(logcdf) fallback, while distributions with _logccdf (e.g., Normal) use their specialized implementation. The test inspects the computation graph structure rather than just numerical results to ensure the correct code path is exercised.
1 parent 20322a1 commit 93734c2

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/logprob/test_abstract.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,48 @@ def test_logccdf_numerical_stability():
139139
# Should be around -52, not -inf or nan
140140
assert np.isfinite(result)
141141
np.testing.assert_almost_equal(result, expected, decimal=6)
142+
143+
144+
def test_logccdf_helper_fallback():
145+
"""Test that _logccdf_helper falls back to log1mexp(logcdf) for distributions without logccdf.
146+
147+
What: Verifies that the helper's NotImplementedError fallback branch is exercised
148+
and produces the correct graph structure.
149+
150+
Why: Distributions without a registered _logccdf method should still work via
151+
the fallback computation log(1 - exp(logcdf)) = log1mexp(logcdf).
152+
153+
How: Uses Uniform distribution (which has logcdf but no logccdf) and inspects
154+
the resulting computation graph. For Uniform, the graph should contain log1mexp.
155+
For Normal (which has logccdf), the graph should NOT contain log1mexp.
156+
"""
157+
from pytensor.scalar.math import Log1mexp
158+
from pytensor.tensor.elemwise import Elemwise
159+
160+
def graph_contains_log1mexp(var, depth=0, visited=None):
161+
"""Recursively check if computation graph contains Log1mexp scalar op."""
162+
if visited is None:
163+
visited = set()
164+
if id(var) in visited or depth > 20:
165+
return False
166+
visited.add(id(var))
167+
if var.owner:
168+
op = var.owner.op
169+
if isinstance(op, Elemwise) and isinstance(op.scalar_op, Log1mexp):
170+
return True
171+
for inp in var.owner.inputs:
172+
if graph_contains_log1mexp(inp, depth + 1, visited):
173+
return True
174+
return False
175+
176+
# Uniform has logcdf but no logccdf - should use log1mexp fallback
177+
uniform_rv = pm.Uniform.dist(0, 1)
178+
uniform_logccdf = _logccdf_helper(uniform_rv, 0.5)
179+
assert graph_contains_log1mexp(uniform_logccdf), "Uniform logccdf should use log1mexp fallback"
180+
181+
# Normal has logccdf - should NOT use log1mexp fallback
182+
normal_rv = pm.Normal.dist(0, 1)
183+
normal_logccdf = _logccdf_helper(normal_rv, 0.5)
184+
assert not graph_contains_log1mexp(normal_logccdf), (
185+
"Normal logccdf should use specialized implementation"
186+
)

0 commit comments

Comments
 (0)