Skip to content

Commit 628e6d5

Browse files
committed
Add test for logccdf IR graph rewriting path
Tests that pm.logccdf works when the random variable depends on transformed parameters, triggering the construct_ir_fgraph fallback path in the public logccdf function.
1 parent 63c9327 commit 628e6d5

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/logprob/test_abstract.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,46 @@ def graph_contains_log1mexp(var, depth=0, visited=None):
226226
assert not graph_contains_log1mexp(normal_logccdf), (
227227
"Normal logccdf should use specialized implementation"
228228
)
229+
230+
231+
def test_logccdf_transformed_argument():
232+
"""Test logccdf with a transformed random variable requiring IR graph rewriting.
233+
234+
What: Tests that pm.logccdf works when the random variable has been
235+
transformed (e.g., sigma is log-transformed), which requires the
236+
IR (intermediate representation) graph rewriting path.
237+
238+
Why: When a random variable depends on transformed parameters, the
239+
direct _logccdf_helper call fails because the RV isn't in the expected
240+
form. The public logccdf function catches this and rewrites the graph
241+
using construct_ir_fgraph to make it work. This test ensures that
242+
fallback path is covered and correct.
243+
244+
How:
245+
1. Creates a model where x ~ Normal(0, sigma) with sigma ~ HalfFlat
246+
(HalfFlat gets log-transformed automatically)
247+
2. Adds a Potential using logccdf(x, 1.0)
248+
3. Compiles and evaluates the model's logp
249+
4. Verifies the result equals:
250+
logp(Normal(0, sigma), x_value) + logsf(1.0; 0, sigma)
251+
252+
The IR rewriting is triggered because x's distribution depends on
253+
the transformed sigma parameter.
254+
"""
255+
with pm.Model() as m:
256+
sigma = pm.HalfFlat("sigma")
257+
x = pm.Normal("x", 0, sigma)
258+
pm.Potential("norm_term", logccdf(x, 1.0))
259+
260+
sigma_value_log = -1.0
261+
sigma_value = np.exp(sigma_value_log) # sigma ≈ 0.368
262+
x_value = 0.5
263+
264+
observed = m.compile_logp(jacobian=False)({"sigma_log__": sigma_value_log, "x": x_value})
265+
266+
# Expected = logp(x | sigma) + logccdf(Normal(0, sigma), 1.0)
267+
expected_logp = pm.logp(pm.Normal.dist(0, sigma_value), x_value).eval()
268+
expected_logsf = sp.norm(0, sigma_value).logsf(1.0)
269+
expected = expected_logp + expected_logsf
270+
271+
assert np.isclose(observed, expected)

0 commit comments

Comments
 (0)