1313# limitations under the License.
1414
1515import hashlib
16+ import json
1617import sys
1718import tempfile
1819from typing import Dict
@@ -43,29 +44,35 @@ def toy_y(toy_X):
4344@pytest .fixture (scope = "module" )
4445def fitted_model_instance (toy_X , toy_y ):
4546 sampler_config = {
46- "draws" : 500 ,
47- "tune" : 300 ,
47+ "draws" : 100 ,
48+ "tune" : 100 ,
4849 "chains" : 2 ,
4950 "target_accept" : 0.95 ,
5051 }
5152 model_config = {
52- "a" : {"loc" : 0 , "scale" : 10 },
53+ "a" : {"loc" : 0 , "scale" : 10 , "dims" : ( "numbers" ,) },
5354 "b" : {"loc" : 0 , "scale" : 10 },
5455 "obs_error" : 2 ,
5556 }
56- model = test_ModelBuilder (model_config = model_config , sampler_config = sampler_config )
57+ model = test_ModelBuilder (
58+ model_config = model_config , sampler_config = sampler_config , test_parameter = "test_paramter"
59+ )
5760 model .fit (toy_X )
5861 return model
5962
6063
6164class test_ModelBuilder (ModelBuilder ):
65+ def __init__ (self , model_config = None , sampler_config = None , test_parameter = None ):
66+ self .test_parameter = test_parameter
67+ super ().__init__ (model_config = model_config , sampler_config = sampler_config )
6268
63- _model_type = "LinearModel "
69+ _model_type = "test_model "
6470 version = "0.1"
6571
6672 def build_model (self , X : pd .DataFrame , y : pd .Series , model_config = None ):
73+ coords = {"numbers" : np .arange (len (X ))}
6774 self .generate_and_preprocess_model_data (X , y )
68- with pm .Model () as self .model :
75+ with pm .Model (coords = coords ) as self .model :
6976 if model_config is None :
7077 model_config = self .default_model_config
7178 x = pm .MutableData ("x" , self .X ["input" ].values )
@@ -79,13 +86,16 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
7986 obs_error = model_config ["obs_error" ]
8087
8188 # priors
82- a = pm .Normal ("a" , a_loc , sigma = a_scale )
89+ a = pm .Normal ("a" , a_loc , sigma = a_scale , dims = model_config [ "a" ][ "dims" ] )
8390 b = pm .Normal ("b" , b_loc , sigma = b_scale )
8491 obs_error = pm .HalfNormal ("σ_model_fmc" , obs_error )
8592
8693 # observed data
8794 output = pm .Normal ("output" , a + b * x , obs_error , shape = x .shape , observed = y_data )
8895
96+ def _save_input_params (self , idata ):
97+ idata .attrs ["test_paramter" ] = json .dumps (self .test_parameter )
98+
8999 @property
90100 def output_var (self ):
91101 return "output"
@@ -107,7 +117,7 @@ def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
107117 @property
108118 def default_model_config (self ) -> Dict :
109119 return {
110- "a" : {"loc" : 0 , "scale" : 10 },
120+ "a" : {"loc" : 0 , "scale" : 10 , "dims" : ( "numbers" ,) },
111121 "b" : {"loc" : 0 , "scale" : 10 },
112122 "obs_error" : 2 ,
113123 }
@@ -122,6 +132,38 @@ def default_sampler_config(self) -> Dict:
122132 }
123133
124134
135+ def test_save_input_params (fitted_model_instance ):
136+ assert fitted_model_instance .idata .attrs ["test_paramter" ] == '"test_paramter"'
137+
138+
139+ def test_save_load (fitted_model_instance ):
140+ temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
141+ fitted_model_instance .save (temp .name )
142+ test_builder2 = test_ModelBuilder .load (temp .name )
143+ assert fitted_model_instance .idata .groups () == test_builder2 .idata .groups ()
144+ assert fitted_model_instance .id == test_builder2 .id
145+ x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
146+ prediction_data = pd .DataFrame ({"input" : x_pred })
147+ pred1 = fitted_model_instance .predict (prediction_data ["input" ])
148+ pred2 = test_builder2 .predict (prediction_data ["input" ])
149+ assert pred1 .shape == pred2 .shape
150+ temp .close ()
151+
152+
153+ def test_convert_dims_to_tuple (fitted_model_instance ):
154+ model_config = {
155+ "a" : {
156+ "loc" : 0 ,
157+ "scale" : 10 ,
158+ "dims" : [
159+ "x" ,
160+ ],
161+ },
162+ }
163+ converted_model_config = fitted_model_instance ._convert_dims_to_tuple (model_config )
164+ assert converted_model_config ["a" ]["dims" ] == ("x" ,)
165+
166+
125167def test_initial_build_and_fit (fitted_model_instance , check_idata = True ) -> ModelBuilder :
126168 if check_idata :
127169 assert fitted_model_instance .idata is not None
@@ -162,20 +204,6 @@ def test_fit_no_y(toy_X):
162204@pytest .mark .skipif (
163205 sys .platform == "win32" , reason = "Permissions for temp files not granted on windows CI."
164206)
165- def test_save_load (fitted_model_instance ):
166- temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
167- fitted_model_instance .save (temp .name )
168- test_builder2 = test_ModelBuilder .load (temp .name )
169- assert fitted_model_instance .idata .groups () == test_builder2 .idata .groups ()
170-
171- x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
172- prediction_data = pd .DataFrame ({"input" : x_pred })
173- pred1 = fitted_model_instance .predict (prediction_data ["input" ])
174- pred2 = test_builder2 .predict (prediction_data ["input" ])
175- assert pred1 .shape == pred2 .shape
176- temp .close ()
177-
178-
179207def test_predict (fitted_model_instance ):
180208 x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
181209 prediction_data = pd .DataFrame ({"input" : x_pred })
0 commit comments