1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import os
15+ import os , sys
1616import numpy as np
1717from optparse import OptionParser
1818from py_paddle import swig_paddle , DataProviderConverter
@@ -66,35 +66,42 @@ def load_label(self, label_file):
6666 for v in open (label_file , 'r' ):
6767 self .label [int (v .split ('\t ' )[1 ])] = v .split ('\t ' )[0 ]
6868
69- def get_data (self , data_file ):
69+ def get_data (self , data ):
7070 """
7171 Get input data of paddle format.
7272 """
73- with open (data_file , 'r' ) as fdata :
74- for line in fdata :
75- words = line .strip ().split ()
76- word_slot = [
77- self .word_dict [w ] for w in words if w in self .word_dict
78- ]
79- if not word_slot :
80- print "all words are not in dictionary: %s" , line
81- continue
82- yield [word_slot ]
83-
84- def predict (self , data_file ):
85- """
86- data_file: file name of input data.
87- """
88- input = self .converter (self .get_data (data_file ))
89- output = self .network .forwardTest (input )
90- prob = output [0 ]["value" ]
91- lab = np .argsort (- prob )
92- if self .label is None :
93- print ("%s: predicting label is %d" % (data_file , lab [0 ][0 ]))
94- else :
95- print ("%s: predicting label is %s" %
96- (data_file , self .label [lab [0 ][0 ]]))
73+ for line in data :
74+ words = line .strip ().split ()
75+ word_slot = [
76+ self .word_dict [w ] for w in words if w in self .word_dict
77+ ]
78+ if not word_slot :
79+ print "all words are not in dictionary: %s" , line
80+ continue
81+ yield [word_slot ]
82+
83+ def predict (self , batch_size ):
84+
85+ def batch_predict (batch_data ):
86+ input = self .converter (self .get_data (batch_data ))
87+ output = self .network .forwardTest (input )
88+ prob = output [0 ]["value" ]
89+ labs = np .argsort (- prob )
90+ for idx , lab in enumerate (labs ):
91+ if self .label is None :
92+ print ("predicting label is %d" % (lab [0 ]))
93+ else :
94+ print ("predicting label is %s" %
95+ (self .label [lab [0 ]]))
9796
97+ batch = []
98+ for line in sys .stdin :
99+ batch .append (line )
100+ if len (batch ) == batch_size :
101+ batch_predict (batch )
102+ batch = []
103+ if len (batch ) > 0 :
104+ batch_predict (batch )
98105
99106def option_parser ():
100107 usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
@@ -119,11 +126,13 @@ def option_parser():
119126 default = None ,
120127 help = "dictionary file" )
121128 parser .add_option (
122- "-i" ,
123- "--data" ,
129+ "-c" ,
130+ "--batch_size" ,
131+ type = "int" ,
124132 action = "store" ,
125- dest = "data" ,
126- help = "data file to predict" )
133+ dest = "batch_size" ,
134+ default = 1 ,
135+ help = "the batch size for prediction" )
127136 parser .add_option (
128137 "-w" ,
129138 "--model" ,
@@ -137,13 +146,13 @@ def option_parser():
137146def main ():
138147 options , args = option_parser ()
139148 train_conf = options .train_conf
140- data = options .data
149+ batch_size = options .batch_size
141150 dict_file = options .dict_file
142151 model_path = options .model_path
143152 label = options .label
144153 swig_paddle .initPaddle ("--use_gpu=0" )
145154 predict = SentimentPrediction (train_conf , dict_file , model_path , label )
146- predict .predict (data )
155+ predict .predict (batch_size )
147156
148157
149158if __name__ == '__main__' :
0 commit comments