Skip to content

Commit 737f2bf

Browse files
authored
Merge pull request #731 from reyoung/feature/fix_testing_style
Simplify the testOnePeriod method.
2 parents 82774db + ba68704 commit 737f2bf

File tree

3 files changed

+33
-47
lines changed

3 files changed

+33
-47
lines changed

paddle/api/test/run_tests.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,7 @@ popd > /dev/null
2020

2121
cd $SCRIPTPATH
2222

23-
if [ ! -f ../../dist/*.whl ] ; then # Swig not compiled.
24-
exit 0
25-
fi
26-
27-
rm .test_env -rf
23+
rm -rf .test_env
2824
virtualenv .test_env
2925
source .test_env/bin/activate
3026

paddle/trainer/Tester.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@ limitations under the License. */
1717
#include <fenv.h>
1818
#include <stdio.h>
1919

20-
#include <iostream>
2120
#include <iomanip>
22-
#include <sstream>
21+
#include <iostream>
2322
#include <limits>
23+
#include <sstream>
2424

2525
#include <google/protobuf/text_format.h>
2626

27+
#include "paddle/utils/GlobalConstants.h"
2728
#include "paddle/utils/PythonUtil.h"
2829
#include "paddle/utils/Stat.h"
2930
#include "paddle/utils/Util.h"
30-
#include "paddle/utils/GlobalConstants.h"
3131

32+
#include "TesterConfig.h"
33+
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
3234
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
3335
#include "paddle/gserver/layers/ValidationLayer.h"
34-
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
35-
#include "TesterConfig.h"
3636

3737
namespace paddle {
3838

@@ -66,6 +66,9 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper>& config,
6666
}
6767

6868
void Tester::startTestPeriod() {
69+
if (testDataProvider_) {
70+
testDataProvider_->reset();
71+
}
6972
testEvaluator_->start();
7073
testContext_.cost = 0;
7174
testContext_.numSamples = 0;
@@ -87,27 +90,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
8790
void Tester::testOnePeriod() {
8891
DataBatch dataBatch;
8992
int64_t batchSize = config_->getOptConfig().batch_size();
90-
91-
int batches = std::numeric_limits<int>::max();
92-
9393
std::vector<Argument> outArgs;
94-
9594
startTestPeriod();
96-
for (int i = 0; i < batches; ++i) {
97-
int num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
98-
if (num == 0) {
99-
testDataProvider_->reset();
100-
if (intconfig_->prevBatchState) {
101-
gradientMachine_->resetState();
102-
}
103-
break;
104-
}
95+
while (testDataProvider_->getNextBatch(batchSize, &dataBatch) != 0) {
10596
testOneDataBatch(dataBatch, &outArgs);
10697
}
10798
finishTestPeriod();
10899
}
109100

110101
void Tester::finishTestPeriod() {
102+
if (intconfig_->prevBatchState) {
103+
gradientMachine_->resetState();
104+
}
111105
testEvaluator_->finish();
112106
CHECK_GT(testContext_.numSamples, 0)
113107
<< "There is no samples in your test batch. Possibly "

paddle/trainer/Trainer.cpp

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,38 @@ limitations under the License. */
1717
#include <fenv.h>
1818
#include <stdio.h>
1919

20-
#include <iostream>
2120
#include <iomanip>
22-
#include <sstream>
21+
#include <iostream>
2322
#include <limits>
23+
#include <sstream>
2424

2525
#include <google/protobuf/text_format.h>
2626

27+
#include "paddle/utils/Excepts.h"
28+
#include "paddle/utils/GlobalConstants.h"
2729
#include "paddle/utils/PythonUtil.h"
2830
#include "paddle/utils/Stat.h"
2931
#include "paddle/utils/Util.h"
30-
#include "paddle/utils/Excepts.h"
31-
#include "paddle/utils/GlobalConstants.h"
3232

33-
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
34-
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
35-
#include "paddle/gserver/layers/ValidationLayer.h"
33+
#include "RemoteParameterUpdater.h"
3634
#include "TesterConfig.h"
3735
#include "ThreadParameterUpdater.h"
38-
#include "RemoteParameterUpdater.h"
3936
#include "TrainerConfigHelper.h"
37+
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
38+
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
39+
#include "paddle/gserver/layers/ValidationLayer.h"
4040

4141
P_DEFINE_string(config, "", "Trainer config file");
4242

43-
P_DEFINE_int32(test_period, 0,
43+
P_DEFINE_int32(test_period,
44+
0,
4445
"if equal 0, do test on all test data at the end of "
4546
"each pass. While if equal non-zero, do test on all test "
4647
"data every test_period batches");
47-
P_DEFINE_bool(test_all_data_in_one_period, false,
48-
"This option was deprecated, since we will always do "
49-
"test on all test set ");
48+
P_DEFINE_bool(test_all_data_in_one_period,
49+
false,
50+
"This option was deprecated, since we will always do "
51+
"test on all test set ");
5052

5153
P_DEFINE_bool(local, true, "Train in local mode or not");
5254

@@ -392,10 +394,6 @@ void Trainer::startTrain() {
392394
dataProvider_->reset();
393395
}
394396

395-
if (this->testDataProvider_) {
396-
this->testDataProvider_->reset();
397-
}
398-
399397
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
400398
}
401399

@@ -630,16 +628,14 @@ void Trainer::test() { tester_->test(); }
630628
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
631629
TesterConfig* conf = new TesterConfig;
632630
if (FLAGS_test_period) {
633-
LOG(WARNING)
634-
<< "The meaning of --test_period is changed: "
635-
<< "if equal 0, do test on all test data at the end of "
636-
<< "each pass. While if equal non-zero, do test on all test "
637-
<< "data every test_period batches ";
631+
LOG(WARNING) << "The meaning of --test_period is changed: "
632+
<< "if equal 0, do test on all test data at the end of "
633+
<< "each pass. While if equal non-zero, do test on all test "
634+
<< "data every test_period batches ";
638635
}
639636
if (FLAGS_test_all_data_in_one_period) {
640-
LOG(WARNING)
641-
<< "--test_all_data_in_one_period was deprecated, since "
642-
<< "we will always do test on all test set ";
637+
LOG(WARNING) << "--test_all_data_in_one_period was deprecated, since "
638+
<< "we will always do test on all test set ";
643639
}
644640
conf->testPeriod = FLAGS_test_period;
645641
conf->prevBatchState = FLAGS_prev_batch_state;

0 commit comments

Comments
 (0)