Skip to content

Commit 239cade

Browse files
authored
Merge pull request #411 from backyes/bugfix_test_period
Re-design command options for testing for better understanding
2 parents 9ffa434 + 1f2423a commit 239cade

File tree

6 files changed

+29
-38
lines changed

6 files changed

+29
-38
lines changed

doc/howto/cmd_parameter/arguments.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ It looks like there are a lot of arguments. However, most of them are for develo
143143
</tr>
144144

145145
<tr>
146-
<td class="left" rowspan = "2">testing during training</td><td class="left">test_all_data_in_one_period</td>
146+
<td class="left" rowspan = "2">testing during training</td><td class="left">test_period</td>
147147
<td class="left">√</td><td class="left">√</td><td class="left"></td><td class="left"></td>
148148
</tr>
149149

doc/howto/cmd_parameter/detail_introduction.md

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
- type: string (default: null).
3232

3333
* `--version`
34-
- Whether to print version infomatrion.
34+
- Whether to print version information.
3535
- type: bool (default: 0).
3636

3737
* `--show_layer_stat`
@@ -110,8 +110,8 @@
110110
- type: int32 (default: -1).
111111

112112
* `--test_period`
113-
- Run testing every test_period train batches. If not set, run testing each pass.
114-
- type: int32 (default: 1000).
113+
- If 0, run testing on all test data at the end of each pass. If non-zero, run testing on all test data every test_period batches.
114+
- type: int32 (default: 0).
115115

116116
* `--test_wait`
117117
- Whether to wait for the parameters of each pass if they do not exist yet. If test_data_path is set in the cluster submitting environment, one process is launched to perform testing, so test_wait=1 needs to be set. Note that in the cluster submitting environment, this argument is set to True by default.
@@ -121,10 +121,6 @@
121121
- File that saves the model list when testing. It was set automatically when using cluster submitting environment after setting model_path.
122122
- type: string (default: "", null).
123123

124-
* `--test_all_data_in_one_period`
125-
- This argument is usually used in testing period during traning. If true, all data will be tested in one test period. Otherwise (batch_size * log_peroid) data will be tested.
126-
- type: bool (default: 0).
127-
128124
* `--predict_output_dir`
129125
- Directory that saves the layer output. It is configured in Outputs() in network config. By default, this argument is null, meaning nothing is saved. Specify this directory if you want to save the feature map of some layers in testing mode. Note that layer outputs are the values after the activation function.
130126
- type: string (default: "", null).

doc/howto/cmd_parameter/use_case.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ paddle train \
1010
--config=network_config \
1111
--save_dir=output \
1212
--trainer_count=COUNT \ #(default:1)
13-
--test_period=M \ #(default:1000)
14-
--test_all_data_in_one_period=true \ #(default:false)
15-
--num_passes=N \ #(defalut:100)
13+
--test_period=M \ #(default:0)
14+
--num_passes=N \ #(default:100)
1615
--log_period=K \ #(default:100)
1716
--dot_period=1000 \ #(default:1)
1817
#[--show_parameter_stats_period=100] \ #(default:0)

paddle/trainer/Tester.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,8 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
8787
void Tester::testOnePeriod() {
8888
DataBatch dataBatch;
8989
int64_t batchSize = config_->getOptConfig().batch_size();
90-
bool testAllData =
91-
intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod;
92-
int batches =
93-
testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod;
90+
91+
int batches = std::numeric_limits<int>::max();
9492

9593
std::vector<Argument> outArgs;
9694

@@ -102,11 +100,7 @@ void Tester::testOnePeriod() {
102100
if (intconfig_->prevBatchState) {
103101
gradientMachine_->resetState();
104102
}
105-
if (testAllData) {
106-
break;
107-
} else {
108-
num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
109-
}
103+
break;
110104
}
111105
testOneDataBatch(dataBatch, &outArgs);
112106
}

paddle/trainer/TesterConfig.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@ struct TesterConfig {
3939
*/
4040
int testPeriod;
4141

42-
/**
43-
* indicate whether testing data in one period
44-
*/
45-
bool testAllDataInOnePeriod;
46-
4742
/**
4843
* indicate whether to save previous batch state
4944
*/

paddle/trainer/Trainer.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,16 @@ limitations under the License. */
3939
#include "TrainerConfigHelper.h"
4040

4141
P_DEFINE_string(config, "", "Trainer config file");
42-
P_DEFINE_int32(test_period,
43-
0,
44-
"Run test every so many train batches."
45-
" 0 for testing after each pass."
46-
" If not 0, test log_period batches."
47-
" If 0, test on all test data");
4842

49-
P_DEFINE_bool(local, true, "Train in local mode or not");
43+
P_DEFINE_int32(test_period, 0,
44+
"if equal 0, do test on all test data at the end of "
45+
"each pass. While if equal non-zero, do test on all test "
46+
"data every test_period batches");
47+
P_DEFINE_bool(test_all_data_in_one_period, false,
48+
"This option was deprecated, since we will always do "
49+
"test on all test set ");
5050

51-
P_DEFINE_bool(
52-
test_all_data_in_one_period,
53-
false,
54-
"true will test all data in one test peroid."
55-
"Otherwise test (batch_size * log_peroid) data in one test period.");
51+
P_DEFINE_bool(local, true, "Train in local mode or not");
5652

5753
P_DEFINE_int32(average_test_period,
5854
0,
@@ -633,8 +629,19 @@ void Trainer::test() { tester_->test(); }
633629

634630
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
635631
TesterConfig* conf = new TesterConfig;
632+
if (FLAGS_test_period) {
633+
LOG(WARNING)
634+
<< "The meaning of --test_period is changed: "
635+
<< "if equal 0, do test on all test data at the end of "
636+
<< "each pass. While if equal non-zero, do test on all test "
637+
<< "data every test_period batches ";
638+
}
639+
if (FLAGS_test_all_data_in_one_period) {
640+
LOG(WARNING)
641+
<< "--test_all_data_in_one_period was deprecated, since "
642+
<< "we will always do test on all test set ";
643+
}
636644
conf->testPeriod = FLAGS_test_period;
637-
conf->testAllDataInOnePeriod = FLAGS_test_all_data_in_one_period;
638645
conf->prevBatchState = FLAGS_prev_batch_state;
639646
conf->logPeriod = FLAGS_log_period;
640647
conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;

0 commit comments

Comments
 (0)