@@ -86,12 +86,27 @@ def test_logcdf_helper():
8686
8787
8888def test_logccdf_helper ():
89+ """Test the internal _logccdf_helper function for basic correctness.
90+
91+ What: Tests that _logccdf_helper correctly computes log(1 - CDF(x)),
92+ also known as the log survival function (logsf).
93+
94+ Why: The _logccdf_helper is the internal dispatcher that routes logccdf
95+ computations to distribution-specific implementations. It needs to work
96+ with both symbolic (TensorVariable) and concrete values.
97+
98+ How: Creates a Normal(0, 1) distribution and computes logccdf at values
99+ [0, 1]. Compares against scipy's logsf which is the reference implementation.
100+ Tests both symbolic input (pt.vector) and concrete input ([0, 1]).
101+ """
89102 value = pt .vector ("value" )
90103 x = pm .Normal .dist (0 , 1 )
91104
105+ # Test with symbolic value input
92106 x_logccdf = _logccdf_helper (x , value )
93107 np .testing .assert_almost_equal (x_logccdf .eval ({value : [0 , 1 ]}), sp .norm (0 , 1 ).logsf ([0 , 1 ]))
94108
109+ # Test with concrete value input
95110 x_logccdf = _logccdf_helper (x , [0 , 1 ])
96111 np .testing .assert_almost_equal (x_logccdf .eval (), sp .norm (0 , 1 ).logsf ([0 , 1 ]))
97112
@@ -114,7 +129,20 @@ def test_logcdf_transformed_argument():
114129
115130
116131def test_logccdf ():
117- """Test the public logccdf function."""
132+ """Test the public pm.logccdf function for basic correctness.
133+
134+ What: Tests that the public logccdf API correctly computes the log
135+ complementary CDF (log survival function) for a Normal distribution.
136+
137+ Why: pm.logccdf is the user-facing function that wraps _logccdf_helper
138+ and handles IR graph rewriting when needed. It should produce correct
139+ results for direct RandomVariable inputs.
140+
141+ How: Creates Normal(0, 1), computes logccdf at [0, 1], and compares
142+ against scipy.stats.norm.logsf reference values.
143+ - logsf(0) = log(0.5) ≈ -0.693 (50% probability of exceeding 0)
144+ - logsf(1) ≈ -1.84 (about 15.9% probability of exceeding 1)
145+ """
118146 value = pt .vector ("value" )
119147 x = pm .Normal .dist (0 , 1 )
120148
0 commit comments