1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15-
1615import collections
1716import sys
1817from typing import Optional
1918
2019import arviz as az
2120import blackjax
2221import jax
23- import jax .numpy as jnp
24- import jax .random as random
2522import numpy as np
2623import pymc as pm
27- from pymc import modelcontext
24+ from packaging import version
25+ from pymc .backends .arviz import coords_and_dims_for_inferencedata
26+ from pymc .blocking import DictToArrayBijection , RaveledVars
27+ from pymc .model import modelcontext
2828from pymc .sampling .jax import get_jaxified_graph
2929from pymc .util import RandomSeed , _get_seeds_per_chain , get_default_varnames
3030
3131
3232def convert_flat_trace_to_idata (
3333 samples ,
34- dims = None ,
35- coords = None ,
3634 include_transformed = False ,
3735 postprocessing_backend = "cpu" ,
3836 model = None ,
3937):
4038
4139 model = modelcontext (model )
42- init_position_dict = model .initial_point ()
40+ ip = model .initial_point ()
41+ ip_point_map_info = pm .blocking .DictToArrayBijection .map (ip ).point_map_info
4342 trace = collections .defaultdict (list )
44- astart = pm .blocking .DictToArrayBijection .map (init_position_dict )
4543 for sample in samples :
46- raveld_vars = pm . blocking . RaveledVars (sample , astart . point_map_info )
47- point = pm . blocking . DictToArrayBijection .rmap (raveld_vars , init_position_dict )
44+ raveld_vars = RaveledVars (sample , ip_point_map_info )
45+ point = DictToArrayBijection .rmap (raveld_vars , ip )
4846 for p , v in point .items ():
4947 trace [p ].append (v .tolist ())
5048
@@ -57,19 +55,19 @@ def convert_flat_trace_to_idata(
5755 result = jax .vmap (jax .vmap (jax_fn ))(
5856 * jax .device_put (list (trace .values ()), jax .devices (postprocessing_backend )[0 ])
5957 )
60-
6158 trace = {v .name : r for v , r in zip (vars_to_sample , result )}
59+ coords , dims = coords_and_dims_for_inferencedata (model )
6260 idata = az .from_dict (trace , dims = dims , coords = coords )
6361
6462 return idata
6563
6664
6765def fit_pathfinder (
68- iterations = 5_000 ,
66+ samples = 1000 ,
6967 random_seed : Optional [RandomSeed ] = None ,
7068 postprocessing_backend = "cpu" ,
71- ftol = 1e-4 ,
7269 model = None ,
70+ ** pathfinder_kwargs ,
7371):
7472 """
7573 Fit the pathfinder algorithm as implemented in blackjax
@@ -78,15 +76,15 @@ def fit_pathfinder(
7876
7977 Parameters
8078 ----------
81- iterations : int
82- Number of iterations to run .
79+ samples : int
80+ Number of samples to draw from the fitted approximation .
8381 random_seed : int
8482 Random seed to set.
8583 postprocessing_backend : str
8684 Where to compute transformations of the trace.
8785 "cpu" or "gpu".
88- ftol : float
89- Floating point tolerance
86+ pathfinder_kwargs:
87+ Keyword arguments forwarded to blackjax.vi.pathfinder.approximate
9088
9189 Returns
9290 -------
@@ -96,53 +94,42 @@ def fit_pathfinder(
9694 ---------
9795 https://arxiv.org/abs/2108.03782
9896 """
99-
100- (random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
97+ # Temporary helper: fail early when the installed blackjax is older than 1.0
98+ if version .parse (blackjax .__version__ ).major < 1 :
99+ raise ImportError ("fit_pathfinder requires blackjax 1.0 or above" )
101100
102101 model = modelcontext (model )
103102
104- rvs = [rv .name for rv in model .value_vars ]
105- init_position_dict = model .initial_point ()
106- init_position = [init_position_dict [rv ] for rv in rvs ]
103+ ip = model .initial_point ()
104+ ip_map = DictToArrayBijection .map (ip )
107105
108106 new_logprob , new_input = pm .pytensorf .join_nonshared_inputs (
109- init_position_dict , (model .logp (),), model .value_vars , ()
107+ ip , (model .logp (),), model .value_vars , ()
110108 )
111109
112110 logprob_fn_list = get_jaxified_graph ([new_input ], new_logprob )
113111
114112 def logprob_fn (x ):
115113 return logprob_fn_list (x )[0 ]
116114
117- dim = sum (v .size for v in init_position_dict .values ())
118-
119- rng_key = random .PRNGKey (random_seed )
120- w0 = random .multivariate_normal (rng_key , 2.0 + jnp .zeros (dim ), jnp .eye (dim ))
121- path = blackjax .vi .pathfinder .init (rng_key , logprob_fn , w0 , return_path = True , ftol = ftol )
122-
123- pathfinder = blackjax .kernels .pathfinder (rng_key , logprob_fn , ftol = ftol )
124- state = pathfinder .init (w0 )
125-
126- def inference_loop (rng_key , kernel , initial_state , num_samples ):
127- @jax .jit
128- def one_step (state , rng_key ):
129- state , info = kernel (rng_key , state )
130- return state , (state , info )
115+ [pathfinder_seed , sample_seed ] = _get_seeds_per_chain (random_seed , 2 )
131116
132- keys = jax .random .split (rng_key , num_samples )
133- return jax .lax .scan (one_step , initial_state , keys )
134-
135- _ , rng_key = random .split (rng_key )
136117 print ("Running pathfinder..." , file = sys .stdout )
137- _ , (_ , samples ) = inference_loop (rng_key , pathfinder .step , state , iterations )
138-
139- dims = {
140- var_name : [dim for dim in dims if dim is not None ]
141- for var_name , dims in model .named_vars_to_dims .items ()
142- }
118+ pathfinder_state , _ = blackjax .vi .pathfinder .approximate (
119+ rng_key = jax .random .key (pathfinder_seed ),
120+ logdensity_fn = logprob_fn ,
121+ initial_position = ip_map .data ,
122+ ** pathfinder_kwargs ,
123+ )
124+ samples , _ = blackjax .vi .pathfinder .sample (
125+ rng_key = jax .random .key (sample_seed ),
126+ state = pathfinder_state ,
127+ num_samples = samples ,
128+ )
143129
144130 idata = convert_flat_trace_to_idata (
145- samples , postprocessing_backend = postprocessing_backend , coords = model .coords , dims = dims
131+ samples ,
132+ postprocessing_backend = postprocessing_backend ,
133+ model = model ,
146134 )
147-
148135 return idata
0 commit comments