@@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
4545 /* the size of inputs for norm-layer is 1 */
4646 CHECK_EQ (config_.inputs_size (), 1 );
4747
48- if (useGpu_) {
49- forward_ = FunctionBase::funcRegistrar_.createByType (
50- FUNC_NAME (CrossMapNormal, GPU));
51- backward_ = FunctionBase::funcRegistrar_.createByType (
52- FUNC_NAME (CrossMapNormalGrad, GPU));
53- } else {
54- forward_ = FunctionBase::funcRegistrar_.createByType (
55- FUNC_NAME (CrossMapNormal, CPU));
56- backward_ = FunctionBase::funcRegistrar_.createByType (
57- FUNC_NAME (CrossMapNormalGrad, CPU));
58- }
59- forward_->init (
48+ createFunction (
49+ forward_,
50+ " CrossMapNormal" ,
6051 FuncConfig ().set (" size" , size_).set (" scale" , scale_).set (" pow" , pow_));
61-
62- backward_->init (
52+ createFunction (
53+ backward_,
54+ " CrossMapNormalGrad" ,
6355 FuncConfig ().set (" size" , size_).set (" scale" , scale_).set (" pow" , pow_));
6456
6557 return true ;
@@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
8072 Matrix::resizeOrCreate (denoms_, batchSize, size, /* trans */ false , useGpu_);
8173
8274 dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
83- forward_->calc (
75+ forward_[ 0 ] ->calc (
8476 {Tensor (input->getData (), dims_)},
8577 {Tensor (outV->getData (), dims_), Tensor (denoms_->getData (), dims_)},
8678 {});
@@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
9890 MatrixPtr localOutV = getOutputValue ();
9991 MatrixPtr preOutV = inputLayers_[0 ]->getOutputValue ();
10092
101- backward_->calc ({Tensor (preOutV->getData (), dims_),
102- Tensor (localOutV->getData (), dims_),
103- Tensor (localGrad->getData (), dims_),
104- Tensor (denoms_->getData (), dims_)},
105- {Tensor (preOutGrad->getData (), dims_)},
106- {});
93+ backward_[ 0 ] ->calc ({Tensor (preOutV->getData (), dims_),
94+ Tensor (localOutV->getData (), dims_),
95+ Tensor (localGrad->getData (), dims_),
96+ Tensor (denoms_->getData (), dims_)},
97+ {Tensor (preOutGrad->getData (), dims_)},
98+ {});
10799}
108100} // namespace paddle
0 commit comments