Skip to content

Commit aaecfcc

Browse files
committed
Support predicting the samples from sys.stdin
1 parent db37981 commit aaecfcc

File tree

2 files changed

+47 additions, -38 deletions (lines changed)

2 files changed

+47 additions, -38 deletions (lines changed)

demo/sentiment/predict.py

Lines changed: 41 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
15+
import os, sys
1616
import numpy as np
1717
from optparse import OptionParser
1818
from py_paddle import swig_paddle, DataProviderConverter
@@ -66,35 +66,42 @@ def load_label(self, label_file):
6666
for v in open(label_file, 'r'):
6767
self.label[int(v.split('\t')[1])] = v.split('\t')[0]
6868

69-
def get_data(self, data_file):
69+
def get_data(self, data):
7070
"""
7171
Get input data of paddle format.
7272
"""
73-
with open(data_file, 'r') as fdata:
74-
for line in fdata:
75-
words = line.strip().split()
76-
word_slot = [
77-
self.word_dict[w] for w in words if w in self.word_dict
78-
]
79-
if not word_slot:
80-
print "all words are not in dictionary: %s", line
81-
continue
82-
yield [word_slot]
83-
84-
def predict(self, data_file):
85-
"""
86-
data_file: file name of input data.
87-
"""
88-
input = self.converter(self.get_data(data_file))
89-
output = self.network.forwardTest(input)
90-
prob = output[0]["value"]
91-
lab = np.argsort(-prob)
92-
if self.label is None:
93-
print("%s: predicting label is %d" % (data_file, lab[0][0]))
94-
else:
95-
print("%s: predicting label is %s" %
96-
(data_file, self.label[lab[0][0]]))
73+
for line in data:
74+
words = line.strip().split()
75+
word_slot = [
76+
self.word_dict[w] for w in words if w in self.word_dict
77+
]
78+
if not word_slot:
79+
print "all words are not in dictionary: %s", line
80+
continue
81+
yield [word_slot]
82+
83+
def predict(self, batch_size):
84+
85+
def batch_predict(batch_data):
86+
input = self.converter(self.get_data(batch_data))
87+
output = self.network.forwardTest(input)
88+
prob = output[0]["value"]
89+
labs = np.argsort(-prob)
90+
for idx, lab in enumerate(labs):
91+
if self.label is None:
92+
print("predicting label is %d" % (lab[0]))
93+
else:
94+
print("predicting label is %s" %
95+
(self.label[lab[0]]))
9796

97+
batch = []
98+
for line in sys.stdin:
99+
batch.append(line)
100+
if len(batch) == batch_size:
101+
batch_predict(batch)
102+
batch=[]
103+
if len(batch) > 0:
104+
batch_predict(batch)
98105

99106
def option_parser():
100107
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
@@ -119,11 +126,13 @@ def option_parser():
119126
default=None,
120127
help="dictionary file")
121128
parser.add_option(
122-
"-i",
123-
"--data",
129+
"-c",
130+
"--batch_size",
131+
type="int",
124132
action="store",
125-
dest="data",
126-
help="data file to predict")
133+
dest="batch_size",
134+
default=1,
135+
help="the batch size for prediction")
127136
parser.add_option(
128137
"-w",
129138
"--model",
@@ -137,13 +146,13 @@ def option_parser():
137146
def main():
138147
options, args = option_parser()
139148
train_conf = options.train_conf
140-
data = options.data
149+
batch_size = options.batch_size
141150
dict_file = options.dict_file
142151
model_path = options.model_path
143152
label = options.label
144153
swig_paddle.initPaddle("--use_gpu=0")
145154
predict = SentimentPrediction(train_conf, dict_file, model_path, label)
146-
predict.predict(data)
155+
predict.predict(batch_size)
147156

148157

149158
if __name__ == '__main__':

demo/sentiment/predict.sh

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,9 @@ set -e
1919
model=model_output/pass-00002/
2020
config=trainer_config.py
2121
label=data/pre-imdb/labels.list
22-
python predict.py \
23-
-n $config\
24-
-w $model \
25-
-b $label \
26-
-d ./data/pre-imdb/dict.txt \
27-
-i ./data/aclImdb/test/pos/10007_10.txt
22+
cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
23+
--tconf=$config\
24+
--model=$model \
25+
--label=$label \
26+
--dict=./data/pre-imdb/dict.txt \
27+
--batch_size=1

0 commit comments

Comments
 (0)