Skip to content

Commit e823c95

Browse files
authored
Merge pull request #947 from reyoung/feature/clean_bn_code
Clean BatchNorm Code.
2 parents 8e25fbb + af5d954 commit e823c95

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

paddle/gserver/layers/BatchNormalizationLayer.cpp

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,14 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
5959

6060
void BatchNormalizationLayer::calMovingMeanAndVar() {
6161
// calculating and saving moving mean and variance
62-
MatrixPtr movingMean = movingMean_->getW();
63-
MatrixPtr movingVar = movingVar_->getW();
64-
65-
if (!useGpu_ && FLAGS_trainer_count > 1) {
66-
auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean);
67-
auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar);
68-
CHECK(mvMean && mvVar);
69-
70-
mvMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
71-
mvVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
72-
} else {
73-
// movingMean = movingMean * movingAvgFraction_
74-
// + savedMean_ * (1 - movingAvgFraction_)
75-
movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
76-
// movingVar = movingVar * movingAvgFraction_
77-
// + savedInvVar_ * (1 - movingAvgFraction_)
78-
movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
79-
}
62+
auto& movingMean = movingMean_->getW();
63+
auto& movingVar = movingVar_->getW();
64+
// movingMean = movingMean * movingAvgFraction_
65+
// + savedMean_ * (1 - movingAvgFraction_)
66+
movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
67+
// movingVar = movingVar * movingAvgFraction_
68+
// + savedInvVar_ * (1 - movingAvgFraction_)
69+
movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
8070
}
8171

8272
void BatchNormalizationLayer::setMeanAndStd() {

paddle/math/Matrix.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1973,8 +1973,8 @@ class SharedCpuMatrix : public CpuMatrix {
19731973

19741974
public:
19751975
virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
1976-
void add(Matrix& b, real p1, real p2);
1977-
void add(real p1, real p2);
1976+
virtual void add(Matrix& b, real p1, real p2);
1977+
virtual void add(real p1, real p2);
19781978

19791979
private:
19801980
using Matrix::mul;

0 commit comments

Comments
 (0)