Skip to content

Commit 5c2206c

Browse files
authored
Merge pull request #921 from reyoung/feature/refine_demo_dataprovider
Feature/refine demo dataprovider
2 parents 6577782 + ee2c142 commit 5c2206c

File tree

18 files changed

+74
-49
lines changed

18 files changed

+74
-49
lines changed

demo/gan/data/download_cifar.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/bin/bash
12
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");

demo/gan/data/get_mnist_data.sh

100644100755
File mode changed.

demo/image_classification/data/download_cifar.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/bin/bash
12
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");

demo/image_classification/image_provider.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
#
2323
# {'img_size': 32,
24-
# 'settings': <paddle.trainer.PyDataProviderWrapper.Cls instance at 0x7fea27cb6050>,
24+
# 'settings': a global object,
2525
# 'color': True,
2626
# 'mean_img_size': 32,
2727
# 'meta': './data/cifar-out/batches/batches.meta',
@@ -50,10 +50,10 @@ def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg,
5050

5151
settings.logger.info('Image size: %s', settings.img_size)
5252
settings.logger.info('Meta path: %s', settings.meta_path)
53-
settings.input_types = [
54-
dense_vector(settings.img_raw_size), # image feature
55-
integer_value(settings.num_classes)
56-
] # labels
53+
settings.input_types = {
54+
'image': dense_vector(settings.img_raw_size),
55+
'label': integer_value(settings.num_classes)
56+
}
5757

5858
settings.logger.info('DataProvider Initialization finished')
5959

@@ -83,4 +83,7 @@ def processData(settings, file_list):
8383
img, settings.img_mean, settings.img_size,
8484
settings.is_train, settings.color)
8585
label = data['labels'][i]
86-
yield img_feat.astype('float32'), int(label)
86+
yield {
87+
'image': img_feat.astype('float32'),
88+
'label': int(label)
89+
}

demo/introduction/.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
dataprovider.pyc
2+
empty.list
3+
train.log
4+
output
5+
train.list

demo/introduction/dataprovider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818

1919
# define data types of input: 2 real numbers
20-
@provider(input_types=[dense_vector(1), dense_vector(1)], use_seq=False)
20+
@provider(
21+
input_types={'x': dense_vector(1),
22+
'y': dense_vector(1)}, use_seq=False)
2123
def process(settings, input_file):
2224
for i in xrange(2000):
2325
x = random.random()
24-
yield [x], [2 * x + 0.3]
26+
yield {'x': [x], 'y': [2 * x + 0.3]}

demo/introduction/trainer_config.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
from paddle.trainer_config_helpers import *
1616

1717
# 1. read data. Suppose you saved above python code as dataprovider.py
18-
data_file = 'empty.list'
19-
with open(data_file, 'w') as f:
20-
f.writelines(' ')
2118
define_py_data_sources2(
22-
train_list=data_file,
19+
train_list=['no_matter.txt'],
2320
test_list=None,
2421
module='dataprovider',
2522
obj='process',

demo/quick_start/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ data/test.list
88
data/test.txt
99
data/train.list
1010
data/train.txt
11+
data/pred.list
12+
data/pred.txt
1113
dataprovider_copy_1.py
1214
train.log
1315
output

demo/quick_start/dataprovider_bow.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ def initializer(settings, dictionary, **kwargs):
3131

3232
# settings.input_types specifies the data types that the data provider
3333
# generates.
34-
settings.input_types = [
34+
settings.input_types = {
3535
# The first input is a sparse_binary_vector,
3636
# which means each dimension of the vector is either 0 or 1. It is the
3737
# bag-of-words (BOW) representation of the texts.
38-
sparse_binary_vector(len(dictionary)),
38+
'word': sparse_binary_vector(len(dictionary)),
3939
# The second input is an integer. It represents the category id of the
4040
# sample. 2 means there are two labels in the dataset.
4141
# (1 for positive and 0 for negative)
42-
integer_value(2)
43-
]
42+
'label': integer_value(2)
43+
}
4444

4545

4646
# Declaring a data provider. It has an initializer 'data_initializer'.
@@ -67,12 +67,12 @@ def process(settings, file_name):
6767
# Return the features for the current comment. The first is a list
6868
# of ids representing a 0-1 binary sparse vector of the text,
6969
# the second is the integer id of the label.
70-
yield word_vector, int(label)
70+
yield {'word': word_vector, 'label': int(label)}
7171

7272

7373
def predict_initializer(settings, dictionary, **kwargs):
7474
settings.word_dict = dictionary
75-
settings.input_types = [sparse_binary_vector(len(dictionary))]
75+
settings.input_types = {'word': sparse_binary_vector(len(dictionary))}
7676

7777

7878
# Declaring a data provider for prediction. The difference with process
@@ -83,4 +83,4 @@ def process_predict(settings, file_name):
8383
for line in f:
8484
comment = line.strip().split()
8585
word_vector = [settings.word_dict.get(w, UNK_IDX) for w in comment]
86-
yield word_vector
86+
yield {'word': word_vector}

demo/quick_start/dataprovider_emb.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020
def initializer(settings, dictionary, **kwargs):
2121
settings.word_dict = dictionary
22-
settings.input_types = [
22+
settings.input_types = {
2323
# Define the type of the first input as sequence of integer.
2424
# The values of the integers range from 0 to len(dictionary)-1
25-
integer_value_sequence(len(dictionary)),
25+
'word': integer_value_sequence(len(dictionary)),
2626
# Define the second input for label id
27-
integer_value(2)
28-
]
27+
'label': integer_value(2)
28+
}
2929

3030

3131
@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM)
@@ -35,15 +35,12 @@ def process(settings, file_name):
3535
label, comment = line.strip().split('\t')
3636
words = comment.split()
3737
word_slot = [settings.word_dict.get(w, UNK_IDX) for w in words]
38-
yield word_slot, int(label)
38+
yield {'word': word_slot, 'label': int(label)}
3939

4040

4141
def predict_initializer(settings, dictionary, **kwargs):
4242
settings.word_dict = dictionary
43-
settings.input_types = [
44-
integer_value(
45-
len(dictionary), seq_type=SequenceType.SEQUENCE)
46-
]
43+
settings.input_types = {'word': integer_value_sequence(len(dictionary))}
4744

4845

4946
@provider(init_hook=predict_initializer, should_shuffle=False)
@@ -52,4 +49,4 @@ def process_predict(settings, file_name):
5249
for line in f:
5350
comment = line.strip().split()
5451
word_slot = [settings.word_dict.get(w, UNK_IDX) for w in comment]
55-
yield word_slot
52+
yield {'word': word_slot}

0 commit comments

Comments
 (0)