@@ -388,58 +388,6 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t):
388388 assert np .allclose (y_logp_val , y_logp_ref_val )
389389
390390
391- @pytest .mark .xfail (reason = "see #148" )
392- @pytensor .config .change_flags (compute_test_value = "raise" )
393- @pytest .mark .xfail (reason = "see #148" )
394- def test_initial_values ():
395- srng = pt .random .RandomStream (seed = 2320 )
396-
397- p_S_0 = np .array ([0.9 , 0.1 ])
398- S_0_rv = srng .categorical (p_S_0 , name = "S_0" )
399- S_0_rv .tag .test_value = 0
400-
401- Gamma_at = pt .matrix ("Gamma" )
402- Gamma_at .tag .test_value = np .array ([[0 , 1 ], [1 , 0 ]])
403-
404- s_0_vv = S_0_rv .clone ()
405- s_0_vv .name = "s_0"
406-
407- def step_fn (S_tm1 , Gamma ):
408- S_t = srng .categorical (Gamma [S_tm1 ], name = "S_t" )
409- return S_t
410-
411- S_1T_rv , _ = pytensor .scan (
412- fn = step_fn ,
413- outputs_info = [{"initial" : S_0_rv , "taps" : [- 1 ]}],
414- non_sequences = [Gamma_at ],
415- strict = True ,
416- n_steps = 10 ,
417- name = "S_0T" ,
418- )
419-
420- S_1T_rv .name = "S_1T"
421- s_1T_vv = S_1T_rv .clone ()
422- s_1T_vv .name = "s_1T"
423-
424- logp_parts = conditional_logp ({S_1T_rv : s_1T_vv , S_0_rv : s_0_vv })
425-
426- s_0_val = 0
427- s_1T_val = np .array ([1 , 0 , 1 , 0 , 1 , 1 , 0 , 1 , 0 , 1 ])
428- Gamma_val = np .array ([[0.1 , 0.9 ], [0.9 , 0.1 ]])
429-
430- exp_res = np .log (p_S_0 [s_0_val ])
431- s_prev = s_0_val
432- for s in s_1T_val :
433- exp_res += np .log (Gamma_val [s_prev , s ])
434- s_prev = s
435-
436- S_0T_logp = sum (v .sum () for v in logp_parts .values ())
437- S_0T_logp_fn = pytensor .function ([s_0_vv , s_1T_vv , Gamma_at ], S_0T_logp )
438- res = S_0T_logp_fn (s_0_val , s_1T_val , Gamma_val )
439-
440- assert res == pytest .approx (exp_res )
441-
442-
443391@pytest .mark .parametrize ("remove_asserts" , (True , False ))
444392def test_mode_is_kept (remove_asserts ):
445393 mode = Mode ().including ("local_remove_all_assert" ) if remove_asserts else None
0 commit comments