@@ -151,6 +151,7 @@ class Metropolis(ArrayStepShared):
151151 def __init__ (
152152 self ,
153153 vars = None ,
154+ * ,
154155 S = None ,
155156 proposal_dist = None ,
156157 scaling = 1.0 ,
@@ -159,7 +160,7 @@ def __init__(
159160 model = None ,
160161 mode = None ,
161162 rng = None ,
162- ** kwargs ,
163+ blocked : bool = False ,
163164 ):
164165 """Create an instance of a Metropolis stepper.
165166
@@ -251,7 +252,7 @@ def __init__(
251252
252253 shared = pm .make_shared_replacements (initial_values , vars , model )
253254 self .delta_logp = delta_logp (initial_values , model .logp (), vars , shared )
254- super ().__init__ (vars , shared , rng = rng )
255+ super ().__init__ (vars , shared , blocked = blocked , rng = rng )
255256
256257 def reset_tuning (self ):
257258 """Reset the tuned sampler parameters to their initial values."""
@@ -418,7 +419,17 @@ class BinaryMetropolis(ArrayStep):
418419
419420 _state_class = BinaryMetropolisState
420421
421- def __init__ (self , vars , scaling = 1.0 , tune = True , tune_interval = 100 , model = None , rng = None ):
422+ def __init__ (
423+ self ,
424+ vars ,
425+ * ,
426+ scaling = 1.0 ,
427+ tune = True ,
428+ tune_interval = 100 ,
429+ model = None ,
430+ rng = None ,
431+ blocked : bool = True ,
432+ ):
422433 model = pm .modelcontext (model )
423434
424435 self .scaling = scaling
@@ -432,7 +443,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None,
432443 if not all (v .dtype in pm .discrete_types for v in vars ):
433444 raise ValueError ("All variables must be Bernoulli for BinaryMetropolis" )
434445
435- super ().__init__ (vars , [model .compile_logp ()], rng = rng )
446+ super ().__init__ (vars , [model .compile_logp ()], blocked = blocked , rng = rng )
436447
437448 def astep (self , apoint : RaveledVars , * args ) -> tuple [RaveledVars , StatsType ]:
438449 logp = args [0 ]
@@ -530,7 +541,16 @@ class BinaryGibbsMetropolis(ArrayStep):
530541
531542 _state_class = BinaryGibbsMetropolisState
532543
533- def __init__ (self , vars , order = "random" , transit_p = 0.8 , model = None , rng = None ):
544+ def __init__ (
545+ self ,
546+ vars ,
547+ * ,
548+ order = "random" ,
549+ transit_p = 0.8 ,
550+ model = None ,
551+ rng = None ,
552+ blocked : bool = True ,
553+ ):
534554 model = pm .modelcontext (model )
535555
536556 # Doesn't actually tune, but it's required to emit a sampler stat
@@ -556,7 +576,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
556576 if not all (v .dtype in pm .discrete_types for v in vars ):
557577 raise ValueError ("All variables must be binary for BinaryGibbsMetropolis" )
558578
559- super ().__init__ (vars , [model .compile_logp ()], rng = rng )
579+ super ().__init__ (vars , [model .compile_logp ()], blocked = blocked , rng = rng )
560580
561581 def reset_tuning (self ):
562582 # There are no tuning parameters in this step method.
@@ -638,7 +658,14 @@ class CategoricalGibbsMetropolis(ArrayStep):
638658 _state_class = CategoricalGibbsMetropolisState
639659
640660 def __init__ (
641- self , vars , proposal = "uniform" , order = "random" , model = None , rng : RandomGenerator = None
661+ self ,
662+ vars ,
663+ * ,
664+ proposal = "uniform" ,
665+ order = "random" ,
666+ model = None ,
667+ rng : RandomGenerator = None ,
668+ blocked : bool = True ,
642669 ):
643670 model = pm .modelcontext (model )
644671
@@ -693,7 +720,7 @@ def __init__(
693720 # that indicates whether a draw was done in a tuning phase.
694721 self .tune = True
695722
696- super ().__init__ (vars , [model .compile_logp ()], rng = rng )
723+ super ().__init__ (vars , [model .compile_logp ()], blocked = blocked , rng = rng )
697724
698725 def reset_tuning (self ):
699726 # There are no tuning parameters in this step method.
@@ -858,6 +885,7 @@ class DEMetropolis(PopulationArrayStepShared):
858885 def __init__ (
859886 self ,
860887 vars = None ,
888+ * ,
861889 S = None ,
862890 proposal_dist = None ,
863891 lamb = None ,
@@ -867,7 +895,7 @@ def __init__(
867895 model = None ,
868896 mode = None ,
869897 rng = None ,
870- ** kwargs ,
898+ blocked : bool = True ,
871899 ):
872900 model = pm .modelcontext (model )
873901 initial_values = model .initial_point ()
@@ -902,7 +930,7 @@ def __init__(
902930
903931 shared = pm .make_shared_replacements (initial_values , vars , model )
904932 self .delta_logp = delta_logp (initial_values , model .logp (), vars , shared )
905- super ().__init__ (vars , shared , rng = rng )
933+ super ().__init__ (vars , shared , blocked = blocked , rng = rng )
906934
907935 def astep (self , q0 : RaveledVars ) -> tuple [RaveledVars , StatsType ]:
908936 point_map_info = q0 .point_map_info
@@ -1025,6 +1053,7 @@ class DEMetropolisZ(ArrayStepShared):
10251053 def __init__ (
10261054 self ,
10271055 vars = None ,
1056+ * ,
10281057 S = None ,
10291058 proposal_dist = None ,
10301059 lamb = None ,
@@ -1035,7 +1064,7 @@ def __init__(
10351064 model = None ,
10361065 mode = None ,
10371066 rng = None ,
1038- ** kwargs ,
1067+ blocked : bool = True ,
10391068 ):
10401069 model = pm .modelcontext (model )
10411070 initial_values = model .initial_point ()
@@ -1082,7 +1111,7 @@ def __init__(
10821111
10831112 shared = pm .make_shared_replacements (initial_values , vars , model )
10841113 self .delta_logp = delta_logp (initial_values , model .logp (), vars , shared )
1085- super ().__init__ (vars , shared , rng = rng )
1114+ super ().__init__ (vars , shared , blocked = blocked , rng = rng )
10861115
10871116 def reset_tuning (self ):
10881117 """Reset the tuned sampler parameters and history to their initial values."""
0 commit comments