1919
2020from pymc .distributions import Dirichlet , Normal
2121from pymc .distributions .transforms import log
22- from pymc .model import Model
22+ from pymc .model import Deterministic , Model
2323from pymc .stats .log_density import compute_log_likelihood , compute_log_prior
2424from tests .distributions .test_multivariate import dirichlet_logpdf
2525
@@ -41,7 +41,7 @@ def test_basic(self, transform):
4141 assert m .rvs_to_transforms [x ] is transform
4242
4343 assert res is idata
44- assert res .log_likelihood .dims == {"chain" : 4 , "draw" : 25 , "test_dim" : 3 }
44+ assert res .log_likelihood .sizes == {"chain" : 4 , "draw" : 25 , "test_dim" : 3 }
4545
4646 np .testing .assert_allclose (
4747 res .log_likelihood ["y" ].values ,
@@ -62,7 +62,7 @@ def test_multivariate(self):
6262 idata = InferenceData (posterior = dict_to_dataset ({"p" : p_draws }))
6363 res = compute_log_likelihood (idata )
6464
65- assert res .log_likelihood .dims == {"chain" : 4 , "draw" : 25 , "test_event_dim" : 10 }
65+ assert res .log_likelihood .sizes == {"chain" : 4 , "draw" : 25 , "test_event_dim" : 10 }
6666
6767 np .testing .assert_allclose (
6868 res .log_likelihood ["y" ].values ,
@@ -149,7 +149,26 @@ def test_basic_log_prior(self, transform):
149149 assert m .rvs_to_transforms [x ] is transform
150150
151151 assert res is idata
152- assert res .log_prior .dims == {"chain" : 4 , "draw" : 25 }
152+ assert res .log_prior .sizes == {"chain" : 4 , "draw" : 25 }
153+
154+ np .testing .assert_allclose (
155+ res .log_prior ["x" ].values ,
156+ st .norm (0 , 1 ).logpdf (idata .posterior ["x" ].values ),
157+ )
158+
159+ def test_deterministic_log_prior (self ):
160+ with Model () as m :
161+ x = Normal ("x" )
162+ Deterministic ("d" , 2 * x )
163+ Normal ("y" , x , observed = [0 , 1 , 2 ])
164+
165+ idata = InferenceData (posterior = dict_to_dataset ({"x" : np .arange (100 ).reshape (4 , 25 )}))
166+ res = compute_log_prior (idata )
167+
168+ assert res is idata
169+ assert "x" in res .log_prior
170+ assert "d" not in res .log_prior
171+ assert res .log_prior .sizes == {"chain" : 4 , "draw" : 25 }
153172
154173 np .testing .assert_allclose (
155174 res .log_prior ["x" ].values ,
0 commit comments