@@ -71,15 +71,13 @@ def build(self, model, target="all"):
7171 raise Exception ("Currently untested on non-Linux OS" )
7272
7373 def _numpy_to_dat (self , model , x ):
74- if len (self .get_input_variables ()) != 1 :
75- raise Exception ("Currently unsupported for multi-input projects" )
74+ if len (model .get_input_variables ()) != 1 :
75+ raise Exception ("Currently unsupported for multi-input/output projects" )
7676
7777 # Verify numpy array of correct shape
78- expected_shape = (np .newaxis , model .get_input_variables ()[0 ].size ())
79- print (f"Expected model input shape: { expected_shape } " )
80- print (f"Give numpy array shape: { x .shape } " )
81- if expected_shape != x .shape :
82- raise Exception (f'Input shape mismatch, got { x .shape } , expected { expected_shape } ' )
78+ expected_shape = model .get_input_variables ()[0 ].size ()
79+ if expected_shape != x .shape [- 1 ]:
80+ raise Exception (f'Input shape mismatch, got { x .shape } , expected (_, { expected_shape } )' )
8381
8482 # Write to tb_data/tb_input_features.dat
8583 input_dat = open (f'{ model .config .get_output_dir ()} /tb_data/tb_input_features.dat' , 'w' )
@@ -90,16 +88,8 @@ def _numpy_to_dat(self, model, x):
9088
9189 def _dat_to_numpy (self , model ):
9290 expected_shape = model .get_output_variables ()[0 ].size ()
93- y = np .array ([], dtype = float ).reshape (0 , expected_shape )
94-
95- output_dat = open (f'{ model .config .get_output_dir ()} /tb_data/hw_results.dat' , 'r' )
96- for line in output_dat .readlines ():
97- data = [list (map (float , line .strip ().split ()))]
98- if len (data ) != expected_shape :
99- raise Exception ('Error in output file. Does not match expected model output shape.' )
100- y = np .concatenate (y , np .array (data )[np .newaxis , :], axis = 0 )
101- output_dat .close ()
102-
91+ output_file = f'{ model .config .get_output_dir ()} /tb_data/hw_results.dat'
92+ y = np .loadtxt (output_file , dtype = float ).reshape (- 1 , expected_shape )
10393 return y
10494
10595 def hardware_predict (self , model , x ):
0 commit comments