@@ -19,17 +19,15 @@ limitations under the License. */
 namespace paddle {
 
 template <>
-void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
-                                               const CpuMatrix* input_mat,
-                                               const CpuMatrix* weight_mat,
+void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
+                                               const CpuMatrix& input_mat,
+                                               const CpuMatrix& weight_mat,
                                                const CpuIVector& seq_vec,
                                                size_t context_length,
                                                int context_start,
                                                size_t begin_pad) {
   const int* starts = seq_vec.getData();
   const size_t num_sequences = seq_vec.getSize() - 1;
-  auto w_mat = const_cast<CpuMatrix*>(weight_mat);
-  auto in_mat = const_cast<CpuMatrix*>(input_mat);
   for (size_t i = 0; i < num_sequences; ++i) {
     for (size_t j = 0; j < context_length; ++j) {
       int begin = starts[i] + context_start + j;
@@ -39,30 +37,34 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
       if (begin < starts[i]) {
         int64_t pad_size =
             std::min(starts[i] - begin, starts[i + 1] - starts[i]);
-        MatrixPtr mat = out_mat->subMatrix(starts[i], pad_size);
-        if (w_mat) {
-          MatrixPtr sub = w_mat->subMatrix(j, pad_size);
-          mat->addAtOffset(*sub, j * in_mat->getWidth());
+        MatrixPtr mat = out_mat.subMatrix(starts[i], pad_size);
+        if (weight_mat) {
+          MatrixPtr sub =
+              const_cast<CpuMatrix&>(weight_mat).subMatrix(j, pad_size);
+          mat->addAtOffset(*sub, j * input_mat.getWidth());
         }
         dst_begin = starts[i] + pad_size;
         begin = starts[i];
       }
       if (end > starts[i + 1]) {
         int64_t pad_size =
             std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
-        MatrixPtr mat = out_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
-        if (w_mat) {
-          MatrixPtr sub = w_mat->subMatrix(
-              begin_pad + context_start + j - pad_size, pad_size);
-          mat->addAtOffset(*sub, j * in_mat->getWidth());
+        MatrixPtr mat = out_mat.subMatrix(starts[i + 1] - pad_size, pad_size);
+        if (weight_mat) {
+          MatrixPtr sub =
+              const_cast<CpuMatrix&>(weight_mat)
+                  .subMatrix(begin_pad + context_start + j - pad_size,
+                             pad_size);
+          mat->addAtOffset(*sub, j * input_mat.getWidth());
         }
         dst_end = starts[i + 1] - pad_size;
         end = starts[i + 1];
       }
       if (end <= begin) continue;
-      MatrixPtr src = in_mat->subMatrix(begin, end - begin);
-      MatrixPtr dst = out_mat->subMatrix(dst_begin, dst_end - dst_begin);
-      dst->addAtOffset(*src, j * in_mat->getWidth());
+      MatrixPtr src =
+          const_cast<CpuMatrix&>(input_mat).subMatrix(begin, end - begin);
+      MatrixPtr dst = out_mat.subMatrix(dst_begin, dst_end - dst_begin);
+      dst->addAtOffset(*src, j * input_mat.getWidth());
     }
   }
 }
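
Note on the hunk above: passing the matrices by reference removes the const_cast-to-pointer prologue, but two idioms preserve the old behavior. subMatrix() is evidently non-const, so each call on a const input still needs a const_cast, and the former null-pointer guards become truthiness tests on references, which only compiles if the matrix type is testable for emptiness, e.g. via an explicit operator bool over its data pointer. A minimal, self-contained sketch of that second idiom, with a hypothetical DemoMatrix standing in for CpuMatrix:

#include <cassert>
#include <cstddef>

// Hypothetical stand-in for CpuMatrix: an "empty" matrix wraps a null buffer,
// so a reference can still represent "no weight provided".
class DemoMatrix {
public:
  DemoMatrix(float* data, size_t h, size_t w) : data_(data), h_(h), w_(w) {}
  // Lets `if (weight_mat)` compile on a const reference: true iff data exists.
  explicit operator bool() const { return data_ != nullptr; }
  size_t getWidth() const { return w_; }

private:
  float* data_;
  size_t h_, w_;
};

// Mirrors the kernel's guard: take the weight path only when non-empty.
size_t weightWidthOrZero(const DemoMatrix& weight_mat) {
  return weight_mat ? weight_mat.getWidth() : 0;
}

int main() {
  float buf[4] = {0};
  DemoMatrix w(buf, 2, 2), none(nullptr, 0, 0);
  assert(weightWidthOrZero(w) == 2);
  assert(weightWidthOrZero(none) == 0);
  return 0;
}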
@@ -82,40 +84,34 @@ class ContextProjectionForwardFunc : public FunctionBase {
     begin_pad_ = config.get<size_t>("begin_pad");
   }
 
-  void calc(const Arguments& inputs,
-            const Arguments& outputs,
-            const Arguments& inouts) override {
+  void calc(const BufferArgs& inputs,
+            const BufferArgs& outputs,
+            const BufferArgs& inouts) override {
     CHECK_EQ(3, inputs.size());
     CHECK_EQ(1, outputs.size());
     CHECK_EQ(0, inouts.size());
 
-    CHECK(outputs[0].getData() && inputs[0].getData() && inputs[2].getData());
-    CHECK_EQ(outputs[0].dims_.size(), 2);
-    CHECK_EQ(inputs[0].dims_.size(), 2);
-    CHECK_EQ(inputs[1].dims_.size(), 2);
-    CHECK_EQ(inputs[2].dims_.size(), 1);
+    CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data());
+    CHECK_EQ(outputs[0].shape().ndims(), 2);
+    CHECK_EQ(inputs[0].shape().ndims(), 2);
+    CHECK_EQ(inputs[1].shape().ndims(), 2);
+    CHECK_EQ(inputs[2].shape().ndims(), 1);
     /// dim of output = dim of input * context_length
-    CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
+    CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
     /// dim of input == dim of weight
-    CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
+    CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
     /// input and output has the same batch_size
-    CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
-
-    auto out_mat = std::make_shared<typename MatrixT<Device>::type>(
-        outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
-    const auto in_mat = std::make_shared<typename MatrixT<Device>::type>(
-        inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
-    const auto w_mat =
-        !inputs[1].getData()
-            ? nullptr
-            : std::make_shared<typename MatrixT<Device>::type>(
-                  inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
-    typename SequenceT<Device>::type seq_vec(
-        inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
-
-    ContextProjectionForward<Device>(out_mat.get(),
-                                     in_mat.get(),
-                                     w_mat.get(),
+    CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
+
+    auto out_mat = outputs[0].matrix<Device>();
+    auto in_mat = inputs[0].matrix<Device>();
+    auto w_mat = !inputs[1].data()
+                     ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
+                     : inputs[1].matrix<Device>();
+    auto seq_vec = inputs[2].vector<int, Device>();
+    ContextProjectionForward<Device>(out_mat,
+                                     in_mat,
+                                     w_mat,
                                      seq_vec,
                                      context_length_,
                                      context_start_,
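
The rewritten calc() above shrinks because BufferArgs hands out typed views (data(), shape(), matrix<Device>(), vector<int, Device>()) instead of raw getData()/dims_ pairs that every call site had to reassemble into make_shared matrices; an absent weight becomes an empty value-type Matrix(nullptr, 0, 0) instead of a null shared_ptr, so no heap allocation happens per call. A stripped-down sketch of the typed-view idea; DemoBufferArg and its accessors are invented for illustration, not Paddle's real API:

#include <cassert>
#include <cstddef>

// Invented, minimal analogue of a type-erased buffer argument.
struct DemoBufferArg {
  void* data;
  size_t dims[2];
  size_t ndims;

  // Typed views, in the spirit of matrix<Device>() / vector<int, Device>().
  template <typename T>
  T* vectorData() const {
    assert(ndims == 1);
    return static_cast<T*>(data);
  }
  template <typename T>
  T* matrixData(size_t* h, size_t* w) const {
    assert(ndims == 2);
    *h = dims[0];
    *w = dims[1];
    return static_cast<T*>(data);
  }
};

int main() {
  int starts[3] = {0, 4, 9};  // a sequence-start vector, as in inputs[2]
  DemoBufferArg seq{starts, {3, 0}, 1};
  const int* s = seq.vectorData<int>();
  assert(s[2] - s[1] == 5);  // length of the second sequence
  return 0;
}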
@@ -129,18 +125,17 @@ class ContextProjectionForwardFunc : public FunctionBase {
 };
 
 template <>
-void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
-                                                CpuMatrix* in_grad_mat,
-                                                CpuMatrix* w_grad_mat,
+void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
+                                                CpuMatrix& in_grad_mat,
+                                                CpuMatrix& w_grad_mat,
                                                 const CpuIVector& seq_vec,
                                                 size_t context_length,
                                                 int context_start,
                                                 size_t begin_pad,
                                                 bool is_padding,
                                                 size_t total_pad) {
-  CHECK(out_grad_mat);
-  size_t input_dim = in_grad_mat ? in_grad_mat->getWidth()
-                                 : w_grad_mat ? w_grad_mat->getWidth() : 0;
+  size_t input_dim = in_grad_mat ? in_grad_mat.getWidth()
+                                 : w_grad_mat ? w_grad_mat.getWidth() : 0;
   const int* starts = seq_vec.getData();
   size_t num_sequences = seq_vec.getSize() - 1;
   for (size_t i = 0; i < num_sequences; ++i) {
@@ -153,8 +148,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
         int64_t pad_size =
             std::min(starts[i] - begin, starts[i + 1] - starts[i]);
         if (is_padding && w_grad_mat) {
-          MatrixPtr mat = out_grad_mat->subMatrix(starts[i], pad_size);
-          MatrixPtr sub = w_grad_mat->subMatrix(j, pad_size);
+          MatrixPtr mat = out_grad_mat.subMatrix(starts[i], pad_size);
+          MatrixPtr sub = w_grad_mat.subMatrix(j, pad_size);
           sub->addAtOffset(*mat, j * input_dim);
         }
         dst_begin = starts[i] + pad_size;
@@ -165,8 +160,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
             std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
         if (is_padding && w_grad_mat) {
           MatrixPtr mat =
-              out_grad_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
-          MatrixPtr sub = w_grad_mat->subMatrix(
+              out_grad_mat.subMatrix(starts[i + 1] - pad_size, pad_size);
+          MatrixPtr sub = w_grad_mat.subMatrix(
               begin_pad + context_start + j - pad_size, pad_size);
           sub->addAtOffset(*mat, j * input_dim);
         }
@@ -175,8 +170,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
       }
       if (end <= begin) continue;
       if (!in_grad_mat) continue;
-      MatrixPtr src = in_grad_mat->subMatrix(begin, end - begin);
-      MatrixPtr dst = out_grad_mat->subMatrix(dst_begin, dst_end - dst_begin);
+      MatrixPtr src = in_grad_mat.subMatrix(begin, end - begin);
+      MatrixPtr dst = out_grad_mat.subMatrix(dst_begin, dst_end - dst_begin);
       src->addAtOffset(*dst, j * input_dim);
     }
   }
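
One detail in the backward kernel: the old CHECK(out_grad_mat) disappears because a reference is assumed to be bound, and the emptiness tests move into the conditionals. The input_dim selection keeps its nested conditional operator, which is right-associative. A tiny self-contained check of that parse, using plain flags in place of the matrices:

#include <cassert>

int main() {
  // Mirrors: in_grad_mat ? in_grad_mat.getWidth()
  //                      : w_grad_mat ? w_grad_mat.getWidth() : 0;
  // The conditional operator associates right, i.e.
  //   in_grad ? ... : (w_grad ? ... : 0)
  bool has_in_grad = false, has_w_grad = true;
  int in_width = 16, w_width = 16;
  int input_dim = has_in_grad ? in_width : has_w_grad ? w_width : 0;
  assert(input_dim == 16);  // falls through to the weight-grad width
  return 0;
}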
@@ -199,44 +194,37 @@ class ContextProjectionBackwardFunc : public FunctionBase {
     total_pad_ = config.get<size_t>("total_pad");
   }
 
-  void calc(const Arguments& inputs,
-            const Arguments& outputs,
-            const Arguments& inouts) override {
+  void calc(const BufferArgs& inputs,
+            const BufferArgs& outputs,
+            const BufferArgs& inouts) override {
     CHECK_EQ(3, inputs.size());
     CHECK_EQ(1, outputs.size());
     CHECK_EQ(0, inouts.size());
 
-    CHECK(outputs[0].getData() && inputs[2].getData());
-    CHECK_EQ(outputs[0].dims_.size(), 2);
-    CHECK_EQ(inputs[0].dims_.size(), 2);
-    CHECK_EQ(inputs[1].dims_.size(), 2);
-    CHECK_EQ(inputs[2].dims_.size(), 1);
+    CHECK(outputs[0].data() && inputs[2].data());
+    CHECK_EQ(outputs[0].shape().ndims(), 2);
+    CHECK_EQ(inputs[0].shape().ndims(), 2);
+    CHECK_EQ(inputs[1].shape().ndims(), 2);
+    CHECK_EQ(inputs[2].shape().ndims(), 1);
 
     /// dim of input == dim of weight
-    CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
+    CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
     /// input and output has the same batch_size
-    CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
+    CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
     /// dim of output = dim of input * context_length
-    CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
+    CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
 
-    auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
-        outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
+    auto out_grad_mat = outputs[0].matrix<Device>();
     auto in_grad_mat =
-        !inputs[0].getData()
-            ? nullptr
-            : std::make_shared<typename MatrixT<Device>::type>(
-                  inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
-    auto w_grad_mat =
-        !inputs[1].getData()
-            ? nullptr
-            : std::make_shared<typename MatrixT<Device>::type>(
-                  inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
-    typename SequenceT<Device>::type seq_vec(
-        inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
-
-    ContextProjectionBackward<Device>(out_grad_mat.get(),
-                                      in_grad_mat ? in_grad_mat.get() : nullptr,
-                                      w_grad_mat ? w_grad_mat.get() : nullptr,
+        !inputs[0].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
+                          : inputs[0].matrix<Device>();
+    auto w_grad_mat = !inputs[1].data()
+                          ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
+                          : inputs[1].matrix<Device>();
+    auto seq_vec = inputs[2].vector<int, Device>();
+    ContextProjectionBackward<Device>(out_grad_mat,
+                                      in_grad_mat,
+                                      w_grad_mat,
                                       seq_vec,
                                       context_length_,
                                       context_start_,
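
The CHECK_EQs above encode the context-projection shape contract: each output row concatenates context_length input windows, so an input of width 16 with context_length 3 must produce output width 48, and the padding weight shares the input width. The same arithmetic outside the framework, with hypothetical shapes:

#include <cassert>
#include <cstddef>

int main() {
  // Hypothetical shapes consistent with the CHECK_EQs in calc().
  size_t batch = 10, input_dim = 16, context_length = 3;
  size_t out_shape[2] = {batch, input_dim * context_length};
  size_t in_shape[2] = {batch, input_dim};
  size_t w_shape[2] = {/*padding rows*/ 2, input_dim};

  assert(out_shape[1] == in_shape[1] * context_length);  // output = input * len
  assert(in_shape[1] == w_shape[1]);                     // input dim == weight dim
  assert(in_shape[0] == out_shape[0]);                   // same batch size
  return 0;
}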
@@ -253,6 +241,7 @@ class ContextProjectionBackwardFunc : public FunctionBase {
   size_t total_pad_;
 };
 
+#if 0
 /**
  * \param inputs[0] input grad.
  * \param inputs[1] input sequence.
@@ -272,6 +261,7 @@ class ContextProjectionBackwardDataFunc : public FunctionBase {
     CHECK_EQ(2, inputs.size());
     CHECK_EQ(1, outputs.size());
     CHECK_EQ(0, inouts.size());
+
     CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
     CHECK_EQ(outputs[0].dims_.size(), 2);
     CHECK_EQ(inputs[0].dims_.size(), 2);
@@ -349,6 +339,7 @@ class ContextProjectionBackwardWeightFunc : public FunctionBase {
   size_t begin_pad_;
   size_t total_pad_;
 };
+#endif
 
 REGISTER_TYPED_FUNC(ContextProjectionForward,
                     CPU,
@@ -363,11 +354,13 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
 REGISTER_TYPED_FUNC(ContextProjectionBackward,
                     GPU,
                     ContextProjectionBackwardFunc);
+#if 0
 REGISTER_TYPED_FUNC(ContextProjectionBackwardData,
                     GPU,
                     ContextProjectionBackwardDataFunc);
 REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
                     GPU,
                     ContextProjectionBackwardWeightFunc);
 #endif
+#endif
 }  // namespace paddle
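
For reference, REGISTER_TYPED_FUNC presumably expands to a static registrar that maps a (name, device) key to a factory, which is why the disabled BackwardData/BackwardWeight registrations must be fenced out with the same #if 0 as their class definitions. A simplified sketch of that registration idiom; all names here are illustrative, not Paddle's actual machinery:

#include <functional>
#include <map>
#include <memory>
#include <string>

// Illustrative stand-ins; Paddle's real FunctionBase/registry differ.
struct DemoFunctionBase {
  virtual ~DemoFunctionBase() = default;
};

using Factory = std::function<std::unique_ptr<DemoFunctionBase>()>;

std::map<std::string, Factory>& registry() {
  static std::map<std::string, Factory> r;
  return r;
}

// A static registrar runs at load time, like REGISTER_TYPED_FUNC.
struct Registrar {
  Registrar(const std::string& name, Factory f) {
    registry()[name] = std::move(f);
  }
};

#define DEMO_REGISTER_FUNC(name, device, cls)                       \
  static Registrar g_reg_##name##_##device(#name "-" #device, [] {  \
    return std::unique_ptr<DemoFunctionBase>(new cls());            \
  })

struct ContextProjectionForwardDemo : DemoFunctionBase {};
DEMO_REGISTER_FUNC(ContextProjectionForward, CPU, ContextProjectionForwardDemo);

int main() {
  // Look the function up by its registered key, as a framework would.
  auto fn = registry().at("ContextProjectionForward-CPU")();
  return fn ? 0 : 1;
}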