@@ -224,3 +224,46 @@ def graph_contains_log1mexp(var, depth=0, visited=None):
224224 normal_rv = pm .Normal .dist (0 , 1 )
225225 normal_logccdf = _logccdf_helper (normal_rv , 0.5 )
226226 assert not graph_contains_log1mexp (normal_logccdf ), "Normal logccdf should use specialized implementation"
227+
228+
229+ def test_logccdf_transformed_argument ():
230+ """Test logccdf with a transformed random variable requiring IR graph rewriting.
231+
232+ What: Tests that pm.logccdf works when the random variable has been
233+ transformed (e.g., sigma is log-transformed), which requires the
234+ IR (intermediate representation) graph rewriting path.
235+
236+ Why: When a random variable depends on transformed parameters, the
237+ direct _logccdf_helper call fails because the RV isn't in the expected
238+ form. The public logccdf function catches this and rewrites the graph
239+ using construct_ir_fgraph to make it work. This test ensures that
240+ fallback path is covered and correct.
241+
242+ How:
243+ 1. Creates a model where x ~ Normal(0, sigma) with sigma ~ HalfFlat
244+ (HalfFlat gets log-transformed automatically)
245+ 2. Adds a Potential using logccdf(x, 1.0)
246+ 3. Compiles and evaluates the model's logp
247+ 4. Verifies the result equals:
248+ logp(Normal(0, sigma), x_value) + logsf(1.0; 0, sigma)
249+
250+ The IR rewriting is triggered because x's distribution depends on
251+ the transformed sigma parameter.
252+ """
253+ with pm .Model () as m :
254+ sigma = pm .HalfFlat ("sigma" )
255+ x = pm .Normal ("x" , 0 , sigma )
256+ pm .Potential ("norm_term" , logccdf (x , 1.0 ))
257+
258+ sigma_value_log = - 1.0
259+ sigma_value = np .exp (sigma_value_log ) # sigma ≈ 0.368
260+ x_value = 0.5
261+
262+ observed = m .compile_logp (jacobian = False )({"sigma_log__" : sigma_value_log , "x" : x_value })
263+
264+ # Expected = logp(x | sigma) + logccdf(Normal(0, sigma), 1.0)
265+ expected_logp = pm .logp (pm .Normal .dist (0 , sigma_value ), x_value ).eval ()
266+ expected_logsf = sp .norm (0 , sigma_value ).logsf (1.0 )
267+ expected = expected_logp + expected_logsf
268+
269+ assert np .isclose (observed , expected )
0 commit comments