Skip to content

Commit 68156c8

Browse files
committed
Modify the argument type of Function
1 parent c5c8051 commit 68156c8

File tree

4 files changed

+56
-97
lines changed

4 files changed

+56
-97
lines changed

paddle/function/CrossMapNormalOp.cpp

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -125,27 +125,25 @@ class CrossMapNormalFunc : public FunctionBase {
125125
pow_ = config.get<real>("pow");
126126
}
127127

128-
void calc(const Arguments& inputs,
129-
const Arguments& outputs,
130-
const Arguments& inouts) override {
128+
void calc(const BufferArgs& inputs,
129+
const BufferArgs& outputs,
130+
const BufferArgs& inouts) override {
131131
CHECK_EQ(1, inputs.size());
132132
CHECK_EQ(2, outputs.size());
133133
CHECK_EQ(0, inouts.size());
134134

135-
CHECK_EQ(inputs[0].dims_.size(), 4);
136-
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
137-
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
138-
CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
139-
}
135+
CHECK_EQ(inputs[0].shape().ndims(), 4);
136+
CHECK(inputs[0].shape() == outputs[0].shape());
137+
CHECK(inputs[0].shape() == outputs[1].shape());
140138

141-
size_t samples = inputs[0].dims_[0];
142-
size_t channels = inputs[0].dims_[1];
143-
size_t height = inputs[0].dims_[2];
144-
size_t width = inputs[0].dims_[3];
139+
size_t samples = inputs[0].shape()[0];
140+
size_t channels = inputs[0].shape()[1];
141+
size_t height = inputs[0].shape()[2];
142+
size_t width = inputs[0].shape()[3];
145143

146-
CrossMapNormal<Device>(outputs[0].getData(),
147-
outputs[1].getData(),
148-
inputs[0].getData(),
144+
CrossMapNormal<Device>(outputs[0].data<real>(),
145+
outputs[1].data<real>(),
146+
inputs[0].data<real>(),
149147
samples,
150148
channels,
151149
height,
@@ -177,31 +175,29 @@ class CrossMapNormalGradFunc : public FunctionBase {
177175
pow_ = config.get<real>("pow");
178176
}
179177

180-
void calc(const Arguments& inputs,
181-
const Arguments& outputs,
182-
const Arguments& inouts) override {
178+
void calc(const BufferArgs& inputs,
179+
const BufferArgs& outputs,
180+
const BufferArgs& inouts) override {
183181
CHECK_EQ(4, inputs.size());
184182
CHECK_EQ(1, outputs.size());
185183
CHECK_EQ(0, inouts.size());
186184

187-
CHECK_EQ(inputs[0].dims_.size(), 4);
188-
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
189-
CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
190-
CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
191-
CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
192-
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
193-
}
194-
195-
size_t samples = inputs[0].dims_[0];
196-
size_t channels = inputs[0].dims_[1];
197-
size_t height = inputs[0].dims_[2];
198-
size_t width = inputs[0].dims_[3];
199-
200-
CrossMapNormalGrad<Device>(outputs[0].getData(),
201-
inputs[0].getData(),
202-
inputs[1].getData(),
203-
inputs[2].getData(),
204-
inputs[3].getData(),
185+
CHECK_EQ(inputs[0].shape().ndims(), 4);
186+
CHECK(inputs[0].shape() == inputs[1].shape());
187+
CHECK(inputs[0].shape() == inputs[2].shape());
188+
CHECK(inputs[0].shape() == inputs[3].shape());
189+
CHECK(inputs[0].shape() == outputs[0].shape());
190+
191+
size_t samples = inputs[0].shape()[0];
192+
size_t channels = inputs[0].shape()[1];
193+
size_t height = inputs[0].shape()[2];
194+
size_t width = inputs[0].shape()[3];
195+
196+
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
197+
inputs[0].data<real>(),
198+
inputs[1].data<real>(),
199+
inputs[2].data<real>(),
200+
inputs[3].data<real>(),
205201
samples,
206202
channels,
207203
height,

paddle/function/Function.h

Lines changed: 4 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,57 +16,12 @@ limitations under the License. */
1616

1717
#include <map>
1818
#include <vector>
19+
#include "BufferArg.h"
1920
#include "paddle/math/Matrix.h"
2021
#include "paddle/utils/ClassRegistrar.h"
2122

2223
namespace paddle {
2324

24-
enum DeviceType {
25-
DEVICE_TYPE_UNSPECIFIED = 0,
26-
DEVICE_TYPE_CPU = 1,
27-
DEVICE_TYPE_GPU = 2,
28-
};
29-
30-
template <DeviceType Device>
31-
struct MatrixT;
32-
33-
template <>
34-
struct MatrixT<DEVICE_TYPE_CPU> {
35-
using type = CpuMatrix;
36-
};
37-
38-
template <>
39-
struct MatrixT<DEVICE_TYPE_GPU> {
40-
using type = GpuMatrix;
41-
};
42-
43-
template <DeviceType Device>
44-
struct SequenceT;
45-
46-
template <>
47-
struct SequenceT<DEVICE_TYPE_CPU> {
48-
using type = CpuIVector;
49-
};
50-
51-
template <>
52-
struct SequenceT<DEVICE_TYPE_GPU> {
53-
using type = GpuIVector;
54-
};
55-
56-
typedef std::vector<size_t> Dims;
57-
58-
class Tensor {
59-
public:
60-
Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
61-
62-
real* getData() const { return buf_; }
63-
64-
real* buf_;
65-
Dims dims_;
66-
};
67-
68-
typedef std::vector<Tensor> Arguments;
69-
7025
class FuncConfig {
7126
public:
7227
union value {
@@ -92,9 +47,9 @@ class FunctionBase {
9247

9348
virtual void init(const FuncConfig& config) {}
9449

95-
virtual void calc(const Arguments& inputs,
96-
const Arguments& outputs,
97-
const Arguments& inouts) {}
50+
virtual void calc(const BufferArgs& inputs,
51+
const BufferArgs& outputs,
52+
const BufferArgs& inouts) {}
9853

9954
static ClassRegistrar<FunctionBase> funcRegistrar_;
10055
};

paddle/gserver/layers/NormProjectionLayer.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,16 @@ void CMRProjectionNormLayer::forward(PassType passType) {
7171

7272
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
7373

74-
dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
75-
forward_[0]->calc(
76-
{Tensor(input->getData(), dims_)},
77-
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
78-
{});
74+
shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
75+
76+
BufferArgs inputs;
77+
BufferArgs outputs;
78+
BufferArgs inouts;
79+
inputs.addArg(*input, shape_);
80+
outputs.addArg(*outV, shape_);
81+
outputs.addArg(*denoms_, shape_);
82+
83+
forward_[0]->calc(inputs, outputs, inouts);
7984
}
8085

8186
void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
@@ -90,11 +95,14 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
9095
MatrixPtr localOutV = getOutputValue();
9196
MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
9297

93-
backward_[0]->calc({Tensor(preOutV->getData(), dims_),
94-
Tensor(localOutV->getData(), dims_),
95-
Tensor(localGrad->getData(), dims_),
96-
Tensor(denoms_->getData(), dims_)},
97-
{Tensor(preOutGrad->getData(), dims_)},
98-
{});
98+
BufferArgs inputs;
99+
BufferArgs outputs;
100+
BufferArgs inouts;
101+
inputs.addArg(*preOutV, shape_);
102+
inputs.addArg(*localOutV, shape_);
103+
inputs.addArg(*localGrad, shape_);
104+
inputs.addArg(*denoms_, shape_);
105+
outputs.addArg(*preOutGrad, shape_);
106+
backward_[0]->calc(inputs, outputs, inouts);
99107
}
100108
} // namespace paddle

paddle/gserver/layers/NormProjectionLayer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,6 @@ class CMRProjectionNormLayer : public ResponseNormLayer {
4141
void backward(const UpdateCallback& callback = nullptr);
4242

4343
protected:
44-
Dims dims_;
44+
TensorShape shape_;
4545
};
4646
} // namespace paddle

0 commit comments

Comments
 (0)