Commit 9049369

Merge pull request #934 from tianbingsz/paddle_function_mat
Matrix API refactor
2 parents: dadd48a + 4fbf949

23 files changed: 247 additions, 290 deletions
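Every hunk in this commit makes the same mechanical change: Matrix methods such as mul and classificationError that previously took MatrixPtr (std::shared_ptr<Matrix>) arguments now take Matrix references, so call sites dereference their pointers with *, while arguments that are allowed to be absent are handed over as raw pointers via .get() or nullptr. A standalone sketch of the call-site pattern, with illustrative stand-in types rather than the real declarations from paddle/math/Matrix.h:

#include <memory>

// Illustrative stand-in for paddle::Matrix; the real class lives in paddle/math/Matrix.h.
struct Matrix {
  // New-style signature: required operands are references and can never be null.
  void mul(const Matrix& a, const Matrix& b, float scaleAB, float scaleT) {
    (void)a; (void)b; (void)scaleAB; (void)scaleT;  // body elided in this sketch
  }
};
using MatrixPtr = std::shared_ptr<Matrix>;

void forwardLike(MatrixPtr in, MatrixPtr weight, MatrixPtr out) {
  // Before the refactor the call would have been: out->mul(in, weight, 1, 1);
  out->mul(*in, *weight, 1, 1);  // after: dereference; ownership stays with the caller
}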

paddle/gserver/evaluators/Evaluator.cpp

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ class ClassificationErrorEvaluator : public Evaluator {
                 useGpu(arguments[0].deviceId));
   errorMat->zeroMem();
   if (label != nullptr) {
-    errorMat->classificationError(output, label);
+    errorMat->classificationError(*output, *label);
   } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
              dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
     errorMat->classificationErrorMulti(

paddle/gserver/layers/ContextProjection.cpp

Lines changed: 6 additions & 6 deletions
@@ -90,8 +90,8 @@ void ContextProjection::forward() {
   REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str());
   bool isPadding = config_.trainable_padding();
   out_->value->contextProjectionForward(
-      in_->value,
-      state_ ? state_ : isPadding ? weight_->getW() : nullptr,
+      *(in_->value),
+      state_ ? state_.get() : isPadding ? weight_->getW().get() : nullptr,
       *startPositions,
       config_.context_length(),
       config_.context_start(),
@@ -128,24 +128,24 @@ void ContextProjection::backward(const UpdateCallback& callback) {
   bool isPadding = config_.trainable_padding();
   if (!out_->grad->useGpu()) {
     out_->grad->contextProjectionBackward(
-        in_->grad,
-        isPadding ? weight_->getWGrad() : nullptr,
+        in_->grad.get(),
+        isPadding ? weight_->getWGrad().get() : nullptr,
         *startPositions,
         config_.context_length(),
         config_.context_start(),
         beginPad_,
         isPadding);
   } else {
     if (in_->grad) {
-      out_->grad->contextProjectionBackwardData(in_->grad,
+      out_->grad->contextProjectionBackwardData(*(in_->grad),
                                                 *startPositions,
                                                 config_.context_length(),
                                                 config_.context_start());
     }

     if (isPadding && weight_->getWGrad()) {
       out_->grad->contextProjectionBackwardWeight(
-          weight_->getWGrad(),
+          *(weight_->getWGrad()),
           *startPositions,
           config_.context_length(),
           config_.context_start(),
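A detail specific to this file: the refactor keeps two calling conventions side by side. Operands that must exist (the projection input) are dereferenced and passed as references, while operands that may legitimately be missing (the trainable padding weight, the recurrent state, an absent input gradient) stay nullable and are passed as raw pointers obtained with .get(). A standalone sketch of that split, using assumed names rather than the real ContextProjection signatures:

#include <memory>

struct Matrix {};
using MatrixPtr = std::shared_ptr<Matrix>;

// Assumed shape of the callee: the input is required, the padding weight is optional.
void contextForwardLike(const Matrix& input, const Matrix* padding /* may be null */) {
  (void)input;
  (void)padding;  // body elided in this sketch
}

void caller(MatrixPtr in, MatrixPtr paddingWeight, bool trainablePadding) {
  contextForwardLike(*in, trainablePadding ? paddingWeight.get() : nullptr);
}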

paddle/gserver/layers/ConvexCombinationLayer.cpp

Lines changed: 3 additions & 3 deletions
@@ -113,7 +113,7 @@ void ConvexCombinationLayer::forward(PassType passType) {
     tmpRow0->setData(inV0->getData() + i * weightDim);
     tmpRow1->setData(outV->getData() + i * dataDim);

-    tmpRow1->mul(tmpRow0, tmpMtx0, 1, 0);
+    tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 0);
   }
 }

@@ -136,7 +136,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) {
     tmpRow1->setData(outG->getData() + i * dataDim);
     tmpMtx0->setData(inV1->getData() + i * weightDim * dataDim);

-    tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1);
+    tmpRow0->mul(*tmpRow1, *(tmpMtx0->getTranspose()), 1, 1);
   }
 }

@@ -146,7 +146,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) {
     tmpRow1->setData(outG->getData() + i * dataDim);
     tmpMtx0->setData(inG1->getData() + i * weightDim * dataDim);

-    tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1, 1, 1);
+    tmpMtx0->mul(*(tmpRow0->getTranspose()), *tmpRow1, 1, 1);
   }
 }
}

paddle/gserver/layers/ExpandConvBaseLayer.cpp

Lines changed: 3 additions & 3 deletions
@@ -150,7 +150,7 @@ void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image,
         Matrix::create(wgtData, subM, subK, false, useGpu_);  // mark transpose
     MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_);
     MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_);
-    C->mul(A, B, 1, 1);
+    C->mul(*A, *B, 1, 1);

     A->clear();
     B->clear();
@@ -185,7 +185,7 @@ void ExpandConvBaseLayer::bpropActs(MatrixPtr out,
     MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_);
     MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_);
     MatrixPtr A = Matrix::create(wgtData, subM, subK, true, useGpu_);
-    C->mul(A, B);  // mul
+    C->mul(*A, *B);  // mul

     // clear the temporary matrix
     A->clear();
@@ -252,7 +252,7 @@ void ExpandConvBaseLayer::bpropWeights(MatrixPtr image,
     MatrixPtr A = Matrix::create(expandInData, subK, subN, true, useGpu_);
     MatrixPtr B = Matrix::create(gradData, subM, subN, false, useGpu_);
     MatrixPtr C = Matrix::create(wGradData, subM, subK, false, useGpu_);
-    C->mul(B, A, 1, 1);
+    C->mul(*B, *A, 1, 1);

     A->clear();
     B->clear();

paddle/gserver/layers/FullMatrixProjection.cpp

Lines changed: 4 additions & 3 deletions
@@ -28,7 +28,7 @@ FullMatrixProjection::FullMatrixProjection(const ProjectionConfig& config,

 void FullMatrixProjection::forward() {
   REGISTER_TIMER_INFO("FwMulTimer", getName().c_str());
-  out_->value->mul(in_->value, weight_->getW(), 1, 1);
+  out_->value->mul(*(in_->value), *(weight_->getW()), 1, 1);
 }

 void FullMatrixProjection::backward(const UpdateCallback& callback) {
@@ -37,7 +37,8 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) {
   /* Calculate the W-gradient for the current layer */
   if (weight_->getWGrad()) {
     REGISTER_TIMER_INFO("GradMulTimer", getName().c_str());
-    weight_->getWGrad()->mul(in_->value->getTranspose(), out_->grad, 1, 1);
+    weight_->getWGrad()->mul(
+        *(in_->value->getTranspose()), *(out_->grad), 1, 1);
   }

   // If callback does not change value, backward propagation error
@@ -47,7 +48,7 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) {
   /* Calculate the input layers error */
   if (in_->grad) {
     REGISTER_TIMER_INFO("BpMulTimer", getName().c_str());
-    in_->grad->mul(out_->grad, weight_->getW()->getTranspose(), 1, 1);
+    in_->grad->mul(*(out_->grad), *(weight_->getW()->getTranspose()), 1, 1);
   }

   hl_set_sync_flag(syncFlag);
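FullMatrixProjection is the simplest place to read the three GEMMs this refactor touches everywhere. From the call sites, mul(a, b, scaleAB, scaleT) appears to follow the BLAS convention this = scaleAB * (a * b) + scaleT * this (an inference from the diff, not a quote of the documented signature), so scaleT = 1 accumulates into the destination and scaleT = 0 overwrites it, which is why FullyConnectedLayer below uses (1, 0) for its first input and (1, 1) for the rest. Under that reading, the three calls in this file compute:

  forward:       out    += in * W
  weight grad:   Wgrad  += in^T * outGrad
  input grad:    inGrad += outGrad * W^T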

paddle/gserver/layers/FullyConnectedLayer.cpp

Lines changed: 4 additions & 4 deletions
@@ -84,8 +84,8 @@ void FullyConnectedLayer::forward(PassType passType) {
     auto input = getInput(i);
     CHECK(input.value) << "The input of 'fc' layer must be matrix";
     REGISTER_TIMER_INFO("FwMulTimer", getName().c_str());
-    i == 0 ? outV->mul(input.value, weights_[i]->getW(), 1, 0)
-           : outV->mul(input.value, weights_[i]->getW(), 1, 1);
+    i == 0 ? outV->mul(*input.value, *weights_[i]->getW(), 1, 0)
+           : outV->mul(*input.value, *weights_[i]->getW(), 1, 1);
   }

   /* add the bias-vector */
@@ -123,7 +123,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) {
       MatrixPtr oGrad = getOutputGrad();
       {
         REGISTER_TIMER_INFO("GradMulTimer", getName().c_str());
-        weights_[i]->getWGrad()->mul(input_T, oGrad, 1, 1);
+        weights_[i]->getWGrad()->mul(*input_T, *oGrad, 1, 1);
       }
     }

@@ -136,7 +136,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) {
     if (NULL != preGrad) {
       MatrixPtr weights_T = weights_[i]->getW()->getTranspose();
       REGISTER_TIMER_INFO("BpMulTimer", getName().c_str());
-      preGrad->mul(getOutputGrad(), weights_T, 1, 1);
+      preGrad->mul(*getOutputGrad(), *weights_T, 1, 1);
     }

   hl_set_sync_flag(syncFlag);

paddle/gserver/layers/LinearChainCRF.cpp

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
   matX->rowMax(*maxX_);
   expX_->assign(*matX);
   // subtract max to avoid overflow or underflow
-  expX_->mul(maxX_, ones_, (real)-1, (real)1);
+  expX_->mul(*maxX_, *ones_, (real)-1, (real)1);
   expX_->exp2();

   real* a = a_->getData();
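Under the same reading of mul(a, b, scaleAB, scaleT) as above, this hunk is the usual overflow guard before exponentiation: expX_ is first assigned matX, the call mul(*maxX_, *ones_, (real)-1, (real)1) subtracts the per-row maximum broadcast across columns (maxX * ones^T), and the exponentiation that follows then only sees values at or below zero:

  expX(i, j) = exp(x(i, j) - max_k x(i, k))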

paddle/gserver/layers/LstmLayer.cpp

Lines changed: 13 additions & 13 deletions
@@ -316,7 +316,7 @@ void LstmLayer::forwardSequence(int batchSize,
     }
     if (prevOutput_) {
       frameGate->setData(lstmValue.gateValue);
-      frameGate->mul(prevOutput_, weight_->getW(), 1, 1);
+      frameGate->mul(*prevOutput_, *weight_->getW(), 1, 1);
     }
   }
   AsyncGpuBlock asyncGpuBlock;
@@ -338,7 +338,7 @@ void LstmLayer::forwardSequence(int batchSize,
         frameOutput->setData(lstmValue.outputValue);
         nextFrame(reversed_, getSize());
         frameGate->setData(lstmValue.gateValue);
-        frameGate->mul(frameOutput, weight_->getW(), 1, 1);
+        frameGate->mul(*frameOutput, *weight_->getW(), 1, 1);
       }
     }
     if (n != numSequences - 1) {
@@ -348,7 +348,7 @@ void LstmLayer::forwardSequence(int batchSize,
       if (!reversed_) {
         if (!prevState_) lstmValue.prevStateValue = nullptr;
         if (prevOutput_) {
-          frameGate->mul(frameOutput, weight_->getW(), 1, 1);
+          frameGate->mul(*frameOutput, *weight_->getW(), 1, 1);
         }
       } else {
         lstmValue.prevStateValue = nullptr;
@@ -470,7 +470,7 @@ void LstmLayer::backwardSequence(int batchSize,
       frameGate->setData(lstmGrad.gateGrad);
       nextFrame(reversed_, getSize());
       frameOutput->setData(lstmGrad.outputGrad);
-      frameOutput->mul(frameGate, weightT, 1, 1);
+      frameOutput->mul(*frameGate, *weightT, 1, 1);
     } else {
       nextFrame(reversed_, getSize());
     }
@@ -479,14 +479,14 @@ void LstmLayer::backwardSequence(int batchSize,
     if (weight_->getWGrad()) {
       if (!reversed_) {
         weight_->getWGrad()->mul(
-            output_.value->subMatrix(start, length - 1)->getTranspose(),
-            gate_.grad->subMatrix(start + 1, length - 1),
+            *output_.value->subMatrix(start, length - 1)->getTranspose(),
+            *gate_.grad->subMatrix(start + 1, length - 1),
             1,
             1);
       } else {
         weight_->getWGrad()->mul(
-            output_.value->subMatrix(start + 1, length - 1)->getTranspose(),
-            gate_.grad->subMatrix(start, length - 1),
+            *output_.value->subMatrix(start + 1, length - 1)->getTranspose(),
+            *gate_.grad->subMatrix(start, length - 1),
             1,
             1);
       }
@@ -541,15 +541,15 @@ void LstmLayer::forwardBatch(int batchSize,

     if (n != 0) {
       MatrixPtr batch1 = batchValue_->getBatchValue(n - 1, batchSize);
-      gateValue->mul(batch1, weight_->getW(), 1, 1);
+      gateValue->mul(*batch1, *weight_->getW(), 1, 1);
     } else if (prevOutput_) {
       Matrix::resizeOrCreate(prevBatchOutput2_,
                              gateValue->getHeight(),
                              getSize(),
                              false,
                              useGpu_);
       batchValue_->prevOutput2Batch(*prevOutput_, *prevBatchOutput2_);
-      gateValue->mul(prevBatchOutput2_, weight_->getW(), 1, 1);
+      gateValue->mul(*prevBatchOutput2_, *weight_->getW(), 1, 1);

       batchValue_->prevOutput2Batch(*prevState_,
                                     *totalState_->subMatrix(0, numSequences));
@@ -672,16 +672,16 @@ void LstmLayer::backwardBatch(int batchSize,

     if (n != 0) {
       MatrixPtr tmp = batchGrad_->getBatchValue(n - 1, batchSize);
-      tmp->mul(gateGrad, weightT, 1, 1);
+      tmp->mul(*gateGrad, *weightT, 1, 1);
     }

     if (n != 0 && weight_->getWGrad()) {
       /* backward weight */
       MatrixPtr outputValue = batchValue_->getBatchValue(n - 1, batchSize);
-      weight_->getWGrad()->mul(outputValue->getTranspose(), gateGrad, 1, 1);
+      weight_->getWGrad()->mul(*outputValue->getTranspose(), *gateGrad, 1, 1);
     } else if (prevOutput_ && weight_->getWGrad()) {
       weight_->getWGrad()->mul(
-          prevBatchOutput2_->getTranspose(), gateGrad, 1, 1);
+          *prevBatchOutput2_->getTranspose(), *gateGrad, 1, 1);
     }
   }
 }
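All of the LstmLayer hunks are the same recurrent GEMMs with dereferenced operands; in the reading of mul used above they amount to (a sketch; the sequence and batch code paths differ only in how rows are grouped):

  gate pre-activations:  gate_t      += h_(t-1) * W
  output gradient:       hGrad_(t-1) += gateGrad_t * W^T
  weight gradient:       Wgrad       += h_(t-1)^T * gateGrad_t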

paddle/gserver/layers/MDLstmLayer.cpp

Lines changed: 4 additions & 4 deletions
@@ -547,7 +547,7 @@ void MDLstmLayer::forwardOneSequence(int start, CoordIterator& coordIter) {
     if (coordIter.getPrePos(delays_, i, prePos)) {
       int preOffset = coordIter.offset(prePos);
       frameGate_[start + offset].value->mul(
-          frameOutput_[start + preOffset].value, weight_->getW(), 1.0, 1.0);
+          *frameOutput_[start + preOffset].value, *weight_->getW(), 1.0, 1.0);
     }
   }
   forwardGate2OutputSequence(start, coordIter);
@@ -747,11 +747,11 @@ void MDLstmLayer::backwardOneSequence(int start, CoordIterator& coordIter) {
     if (coordIter.getPrePos(delays_, i, prePos)) {
       int preOffset = coordIter.offset(prePos);
       frameOutput_[start + preOffset].grad->mul(
-          frameGate_[start + offset].grad, weightT, 1.0, 1.0);
+          *frameGate_[start + offset].grad, *weightT, 1.0, 1.0);
       if (weight_->getWGrad()) {
         weight_->getWGrad()->mul(
-            frameOutput_[start + preOffset].value->getTranspose(),
-            frameGate_[start + offset].grad,
+            *frameOutput_[start + preOffset].value->getTranspose(),
+            *frameGate_[start + offset].grad,
             1.0,
             1.0);
       }

paddle/gserver/layers/OuterProdLayer.cpp

Lines changed: 3 additions & 3 deletions
@@ -96,7 +96,7 @@ void OuterProdLayer::forward(PassType passType) {
       tmpRow0->setData(inV0->getData() + i * dim0);
       tmpRow1->setData(inV1->getData() + i * dim1);

-      tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1);
+      tmpMtx0->mul(*tmpRow0->getTranspose(), *tmpRow1);
     }
   }
 }
@@ -121,7 +121,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) {
       tmpRow0->setData(inG0->getData() + i * dim0);
      tmpRow1->setData(inV1->getData() + i * dim1);

-      tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1);
+      tmpRow0->mul(*tmpRow1, *tmpMtx0->getTranspose(), 1, 1);
     }
   }

@@ -131,7 +131,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) {
       tmpRow0->setData(inV0->getData() + i * dim0);
       tmpRow1->setData(inG1->getData() + i * dim1);

-      tmpRow1->mul(tmpRow0, tmpMtx0, 1, 1);
+      tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 1);
     }
   }
 }
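For OuterProdLayer, the per-sample forward call tmpMtx0->mul(*tmpRow0->getTranspose(), *tmpRow1) forms, in the same reading, the outer product of two row vectors:

  out_i = u_i^T * v_i   (a dim0 x dim1 matrix for sample i)

and the two backward hunks are the matching gradient GEMMs into the dim0-sized and dim1-sized inputs.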
