Commit 148bd4d

add Layer::createFunction

1 parent cee9346 commit 148bd4d

File tree: 2 files changed (+35, -23 lines)

paddle/gserver/layers/Layer.h

Lines changed: 22 additions & 2 deletions
@@ -102,9 +102,9 @@ class Layer {
   std::vector<bool> markInBackward_;
 
   /// Layer forward function
-  FunctionBase* forward_;
+  std::vector<std::shared_ptr<FunctionBase>> forward_;
   /// Layer backward function
-  FunctionBase* backward_;
+  std::vector<std::shared_ptr<FunctionBase>> backward_;
 
 public:
   /**
@@ -132,6 +132,26 @@
   virtual void markAllInputGrad();
 
 protected:
+  /**
+   * Create layer function. Function is called in forward or backward.
+   * \param function, Layer::forward_ or Layer::backward_
+   * \param name, function name
+   * \param config, initialization configuration for the function
+   */
+  void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
+                      const std::string& name,
+                      const FuncConfig& config) {
+    if (useGpu_) {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-GPU"));
+    } else {
+      function.emplace_back(
+          FunctionBase::funcRegistrar_.createByType(name + "-CPU"));
+    }
+    auto& func = function.back();
+    func->init(config);
+  }
+
   /**
    * Notify specified layer the output grad ready.
    * Called in the backward function.
paddle/gserver/layers/NormProjectionLayer.cpp

Lines changed: 13 additions & 21 deletions
@@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
   /* the size of inputs for norm-layer is 1 */
   CHECK_EQ(config_.inputs_size(), 1);
 
-  if (useGpu_) {
-    forward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormal, GPU));
-    backward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormalGrad, GPU));
-  } else {
-    forward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormal, CPU));
-    backward_ = FunctionBase::funcRegistrar_.createByType(
-        FUNC_NAME(CrossMapNormalGrad, CPU));
-  }
-  forward_->init(
+  createFunction(
+      forward_,
+      "CrossMapNormal",
       FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
-
-  backward_->init(
+  createFunction(
+      backward_,
+      "CrossMapNormalGrad",
       FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
 
   return true;
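A corollary of the naming convention is that every function must be registered under both device suffixes before `createByType` can find it. Continuing the hypothetical sketch above, registration for the functions this layer uses might look like:

```cpp
// Continuing the sketch: hypothetical CPU/GPU variants registered under
// device-suffixed names so createByType(name + "-CPU"/"-GPU") resolves.
struct CrossMapNormalCpu : FunctionBase {
  void init(const FuncConfig& /*config*/) override { /* read size/scale/pow */ }
};
struct CrossMapNormalGpu : FunctionBase {
  void init(const FuncConfig& /*config*/) override { /* read size/scale/pow */ }
};

FuncRegistrar gRegistrar;  // stand-in for FunctionBase::funcRegistrar_

void registerNormFunctions() {
  gRegistrar.registerClass("CrossMapNormal-CPU",
                           [] { return new CrossMapNormalCpu(); });
  gRegistrar.registerClass("CrossMapNormal-GPU",
                           [] { return new CrossMapNormalGpu(); });
  // "CrossMapNormalGrad-CPU"/"-GPU" would be registered the same way.
}
```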
@@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
   Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
 
   dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
-  forward_->calc(
+  forward_[0]->calc(
       {Tensor(input->getData(), dims_)},
       {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
       {});
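The call sites wrap raw `Matrix` storage in `Tensor` views tagged with the NCHW `dims_`. The following is the interface those calls appear to assume, inferred from usage rather than copied from Paddle's headers:

```cpp
#include <cstddef>
#include <vector>

// Inferred from the call sites: a Tensor is a non-owning view over a data
// pointer plus a dims vector such as {batchSize, channels, imgSizeH, imgSizeW}.
struct Tensor {
  Tensor(float* data, const std::vector<size_t>& dims)
      : data(data), dims(dims) {}
  float* data;                // borrowed; the Matrix owns the buffer
  std::vector<size_t> dims;
};

// The calc() signature implied by forward_[0]->calc({...}, {...}, {}):
// three tensor lists for inputs, outputs, and in/out arguments. In the real
// FunctionBase this presumably sits alongside init(); it is split out here
// only to keep the earlier sketch unchanged.
struct FunctionWithCalc : FunctionBase {
  virtual void calc(const std::vector<Tensor>& inputs,
                    const std::vector<Tensor>& outputs,
                    const std::vector<Tensor>& inouts) = 0;
};
```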
@@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
   MatrixPtr localOutV = getOutputValue();
   MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
 
-  backward_->calc({Tensor(preOutV->getData(), dims_),
-                   Tensor(localOutV->getData(), dims_),
-                   Tensor(localGrad->getData(), dims_),
-                   Tensor(denoms_->getData(), dims_)},
-                  {Tensor(preOutGrad->getData(), dims_)},
-                  {});
+  backward_[0]->calc({Tensor(preOutV->getData(), dims_),
+                      Tensor(localOutV->getData(), dims_),
+                      Tensor(localGrad->getData(), dims_),
+                      Tensor(denoms_->getData(), dims_)},
+                     {Tensor(preOutGrad->getData(), dims_)},
+                     {});
 }
 } // namespace paddle
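Putting the pieces together: since `forward_` and `backward_` are now vectors, a layer can in principle own several kernels per direction; this layer creates exactly one each, hence `forward_[0]` and `backward_[0]`. A minimal end-to-end sketch of the pattern, built on the hypothetical stand-ins above rather than Paddle's real `Layer`:

```cpp
#include <memory>
#include <string>
#include <vector>

// Mirrors the commit's pattern against the sketch types above: one
// createFunction call per direction in init(), then indexed calls later.
class SketchLayer {
public:
  explicit SketchLayer(bool useGpu) : useGpu_(useGpu) {}

  bool init() {
    createFunction(forward_, "CrossMapNormal", FuncConfig{});
    createFunction(backward_, "CrossMapNormalGrad", FuncConfig{});
    return true;
  }

private:
  // Same shape as Layer::createFunction in the diff, but resolving through
  // the sketch registrar instead of FunctionBase::funcRegistrar_.
  void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
                      const std::string& name,
                      const FuncConfig& config) {
    function.emplace_back(
        gRegistrar.createByType(name + (useGpu_ ? "-GPU" : "-CPU")));
    function.back()->init(config);
  }

  bool useGpu_;
  std::vector<std::shared_ptr<FunctionBase>> forward_;
  std::vector<std::shared_ptr<FunctionBase>> backward_;
};
```

The vector also explains why the helper takes `function` by reference: it appends rather than assigns, so a layer that needs several functions can simply call `createFunction` once per kernel.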
