Skip to content

Commit a1bfef8

Browse files
committed
Add test for logccdf IR graph rewriting path
Tests that pm.logccdf works when the random variable depends on transformed parameters, triggering the construct_ir_fgraph fallback path in the public logccdf function.
1 parent 38b9335 commit a1bfef8

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/logprob/test_abstract.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,46 @@ def graph_contains_log1mexp(var, depth=0, visited=None):
224224
normal_rv = pm.Normal.dist(0, 1)
225225
normal_logccdf = _logccdf_helper(normal_rv, 0.5)
226226
assert not graph_contains_log1mexp(normal_logccdf), "Normal logccdf should use specialized implementation"
227+
228+
229+
def test_logccdf_transformed_argument():
230+
"""Test logccdf with a transformed random variable requiring IR graph rewriting.
231+
232+
What: Tests that pm.logccdf works when the random variable has been
233+
transformed (e.g., sigma is log-transformed), which requires the
234+
IR (intermediate representation) graph rewriting path.
235+
236+
Why: When a random variable depends on transformed parameters, the
237+
direct _logccdf_helper call fails because the RV isn't in the expected
238+
form. The public logccdf function catches this and rewrites the graph
239+
using construct_ir_fgraph to make it work. This test ensures that
240+
fallback path is covered and correct.
241+
242+
How:
243+
1. Creates a model where x ~ Normal(0, sigma) with sigma ~ HalfFlat
244+
(HalfFlat gets log-transformed automatically)
245+
2. Adds a Potential using logccdf(x, 1.0)
246+
3. Compiles and evaluates the model's logp
247+
4. Verifies the result equals:
248+
logp(Normal(0, sigma), x_value) + logsf(1.0; 0, sigma)
249+
250+
The IR rewriting is triggered because x's distribution depends on
251+
the transformed sigma parameter.
252+
"""
253+
with pm.Model() as m:
254+
sigma = pm.HalfFlat("sigma")
255+
x = pm.Normal("x", 0, sigma)
256+
pm.Potential("norm_term", logccdf(x, 1.0))
257+
258+
sigma_value_log = -1.0
259+
sigma_value = np.exp(sigma_value_log) # sigma ≈ 0.368
260+
x_value = 0.5
261+
262+
observed = m.compile_logp(jacobian=False)({"sigma_log__": sigma_value_log, "x": x_value})
263+
264+
# Expected = logp(x | sigma) + logccdf(Normal(0, sigma), 1.0)
265+
expected_logp = pm.logp(pm.Normal.dist(0, sigma_value), x_value).eval()
266+
expected_logsf = sp.norm(0, sigma_value).logsf(1.0)
267+
expected = expected_logp + expected_logsf
268+
269+
assert np.isclose(observed, expected)

0 commit comments

Comments
 (0)