@@ -62,30 +62,37 @@ class SamplerWarning:
6262
6363
6464def run_convergence_checks (idata : arviz .InferenceData , model ) -> list [SamplerWarning ]:
65+ warnings : list [SamplerWarning ] = []
66+
6567 if not hasattr (idata , "posterior" ):
6668 msg = "No posterior samples. Unable to run convergence checks"
6769 warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" , None , None , None )
68- return [warn ]
70+ warnings .append (warn )
71+ return warnings
72+
73+ warnings += warn_divergences (idata )
74+ warnings += warn_treedepth (idata )
6975
7076 if idata ["posterior" ].sizes ["draw" ] < 100 :
7177 msg = "The number of samples is too small to check convergence reliably."
7278 warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" , None , None , None )
73- return [warn ]
79+ warnings .append (warn )
80+ return warnings
7481
7582 if idata ["posterior" ].sizes ["chain" ] == 1 :
7683 msg = "Only one chain was sampled, this makes it impossible to run some convergence checks"
7784 warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" )
78- return [warn ]
85+ warnings .append (warn )
86+ return warnings
7987
8088 elif idata ["posterior" ].sizes ["chain" ] < 4 :
8189 msg = (
8290 "We recommend running at least 4 chains for robust computation of "
8391 "convergence diagnostics"
8492 )
8593 warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" )
86- return [ warn ]
94+ warnings . append ( warn )
8795
88- warnings : list [SamplerWarning ] = []
8996 valid_name = [rv .name for rv in model .free_RVs + model .deterministics ]
9097 varnames = []
9198 for rv in model .free_RVs :
@@ -99,7 +106,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
99106 ess = arviz .ess (idata , var_names = varnames )
100107 rhat = arviz .rhat (idata , var_names = varnames )
101108
102- warnings = []
103109 rhat_max = max (val .max () for val in rhat .values ())
104110 if rhat_max > 1.01 :
105111 msg = (
@@ -121,9 +127,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
121127 warn = SamplerWarning (WarningType .CONVERGENCE , msg , "error" , extra = ess )
122128 warnings .append (warn )
123129
124- warnings += warn_divergences (idata )
125- warnings += warn_treedepth (idata )
126-
127130 return warnings
128131
129132
0 commit comments