Skip to content

Commit 82774db

Browse files
authored
Merge pull request #239 from hedaoyuan/tensor
Add TensorExpression
2 parents 5d26716 + abdcb8e commit 82774db

28 files changed

+4058
-198
lines changed

paddle/cuda/include/hl_matrix_type.cuh

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
1615
#ifndef HL_MATRIX_TYPE_CUH_
1716
#define HL_MATRIX_TYPE_CUH_
1817

1918
#include "hl_base.h"
2019

2120
#ifdef __CUDA_ARCH__
22-
// typedef void* vecType;
2321
#include <vector_types.h>
2422
#ifndef PADDLE_TYPE_DOUBLE
2523
typedef float4 vecType;
@@ -37,4 +35,10 @@ typedef __m128d vecType;
3735
#endif
3836
#endif
3937

40-
#endif /* HL_MATRIX_TYPE_CUH_ */
38+
#ifdef __CUDA_ARCH__
39+
#define INLINE __device__ inline
40+
#else
41+
#define INLINE inline
42+
#endif
43+
44+
#endif // HL_MATRIX_TYPE_CUH_
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifndef HL_TENSOR_OPS_H_
16+
#define HL_TENSOR_OPS_H_
17+
18+
#include <cmath>
19+
#include "hl_matrix_type.cuh"
20+
21+
namespace hppl {
22+
namespace unary {
23+
24+
template <class T>
25+
class add_scale {
26+
private:
27+
const T p;
28+
29+
public:
30+
INLINE add_scale(const T s) : p(s) {}
31+
INLINE T operator()(const T a) const { return a + p; }
32+
};
33+
34+
template <class T>
35+
class sub_scale {
36+
private:
37+
const T p;
38+
39+
public:
40+
INLINE sub_scale(const T s) : p(s) {}
41+
INLINE T operator()(const T a) const { return a - p; }
42+
};
43+
44+
template <class T>
45+
class mul_scale {
46+
private:
47+
const T p;
48+
49+
public:
50+
INLINE mul_scale(const T s) : p(s) {}
51+
INLINE T operator()(const T a) const { return a * p; }
52+
};
53+
54+
template <class T>
55+
class div_scale {
56+
private:
57+
const T p;
58+
59+
public:
60+
INLINE div_scale(const T s) : p(s) {}
61+
INLINE T operator()(const T a) const { return a / p; }
62+
};
63+
64+
template <class T>
65+
class neg {
66+
public:
67+
INLINE T operator()(const T a) const { return -a; }
68+
};
69+
70+
template <class T>
71+
class exp_op {
72+
public:
73+
INLINE T operator()(const T a) const { return std::exp(a); }
74+
};
75+
76+
template <class T>
77+
class log_op {
78+
public:
79+
INLINE T operator()(const T a) const { return std::log(a); }
80+
};
81+
82+
template <class T>
83+
class sqrt_op {
84+
public:
85+
INLINE T operator()(const T a) const { return std::sqrt(a); }
86+
};
87+
88+
template <class T>
89+
class square {
90+
public:
91+
INLINE T operator()(const T a) const { return a * a; }
92+
};
93+
94+
template <class T>
95+
class reciprocal {
96+
public:
97+
INLINE T operator()(const T a) const { return T(1) / a; }
98+
};
99+
100+
template <class T>
101+
class abs {
102+
public:
103+
INLINE T operator()(const T a) const { return a > 0 ? a : -a; }
104+
};
105+
106+
template <class T>
107+
class sign {
108+
public:
109+
INLINE T operator()(const T a) const { return (a > 0) - (a < 0); }
110+
};
111+
112+
template <class T>
113+
class min {
114+
private:
115+
const T p;
116+
117+
public:
118+
INLINE min(const T s) : p(s) {}
119+
INLINE T operator()(const T a) const { return a > p ? p : a; }
120+
};
121+
122+
template <class T>
123+
class max {
124+
private:
125+
const T p;
126+
127+
public:
128+
INLINE max(const T s) : p(s) {}
129+
INLINE T operator()(const T a) const { return a < p ? p : a; }
130+
};
131+
132+
template <class T>
133+
class pow_op {
134+
private:
135+
const T p;
136+
137+
public:
138+
INLINE pow_op(const T s) : p(s) {}
139+
INLINE T operator()(const T a) const { return std::pow(a, p); }
140+
};
141+
142+
template <class T>
143+
class constant {
144+
private:
145+
const T p;
146+
147+
public:
148+
INLINE constant(const T s) : p(s) {}
149+
INLINE T operator()(int i) const { return p; }
150+
INLINE T operator()(int i, int j) const { return p; }
151+
};
152+
153+
template <class T>
154+
class cmp_eq {
155+
private:
156+
const T p;
157+
158+
public:
159+
INLINE cmp_eq(const T s) : p(s) {}
160+
INLINE bool operator()(const T a) const { return a == p; }
161+
};
162+
163+
template <class T>
164+
class cmp_ne {
165+
private:
166+
const T p;
167+
168+
public:
169+
INLINE cmp_ne(const T s) : p(s) {}
170+
INLINE bool operator()(const T a) const { return a != p; }
171+
};
172+
173+
template <class T>
174+
class cmp_le {
175+
private:
176+
const T p;
177+
178+
public:
179+
INLINE cmp_le(const T s) : p(s) {}
180+
INLINE bool operator()(const T a) const { return a <= p; }
181+
};
182+
183+
template <class T>
184+
class cmp_lt {
185+
private:
186+
const T p;
187+
188+
public:
189+
INLINE cmp_lt(const T s) : p(s) {}
190+
INLINE bool operator()(const T a) const { return a < p; }
191+
};
192+
193+
template <class T>
194+
class cmp_ge {
195+
private:
196+
const T p;
197+
198+
public:
199+
INLINE cmp_ge(const T s) : p(s) {}
200+
INLINE bool operator()(const T a) const { return a >= p; }
201+
};
202+
203+
template <class T>
204+
class cmp_gt {
205+
private:
206+
const T p;
207+
208+
public:
209+
INLINE cmp_gt(const T s) : p(s) {}
210+
INLINE bool operator()(const T a) const { return a > p; }
211+
};
212+
213+
template <class T>
214+
class and_op {
215+
private:
216+
const T p;
217+
218+
public:
219+
INLINE and_op(const T s) : p(s) {}
220+
INLINE bool operator()(const T a) const { return a && p; }
221+
};
222+
223+
template <class T>
224+
class or_op {
225+
private:
226+
const T p;
227+
228+
public:
229+
INLINE or_op(const T s) : p(s) {}
230+
INLINE bool operator()(const T a) const { return a || p; }
231+
};
232+
233+
} // namespace unary
234+
235+
namespace binary {
236+
template <class T>
237+
class add {
238+
public:
239+
INLINE T operator()(const T a, const T b) const { return a + b; }
240+
};
241+
242+
template <class T>
243+
class add_scale {
244+
private:
245+
const T p1;
246+
const T p2;
247+
248+
public:
249+
INLINE add_scale(const T s1, const T s2) : p1(s1), p2(s2) {}
250+
INLINE T operator()(const T a, const T b) const { return p1 * a + p2 * b; }
251+
};
252+
253+
template <class T>
254+
class sub {
255+
public:
256+
INLINE T operator()(const T a, const T b) const { return a - b; }
257+
};
258+
259+
template <class T>
260+
class mul {
261+
public:
262+
INLINE T operator()(const T a, const T b) const { return a * b; }
263+
};
264+
265+
template <class T>
266+
class div {
267+
public:
268+
INLINE T operator()(const T a, const T b) const { return a / b; }
269+
};
270+
271+
template <class T>
272+
class cmp_eq {
273+
public:
274+
INLINE bool operator()(const T a, const T b) const { return a == b; }
275+
};
276+
277+
template <class T>
278+
class cmp_ne {
279+
public:
280+
INLINE bool operator()(const T a, const T b) const { return a != b; }
281+
};
282+
283+
template <class T>
284+
class cmp_le {
285+
public:
286+
INLINE bool operator()(const T a, const T b) const { return a <= b; }
287+
};
288+
289+
template <class T>
290+
class cmp_lt {
291+
public:
292+
INLINE bool operator()(const T a, const T b) const { return a < b; }
293+
};
294+
295+
template <class T>
296+
class cmp_ge {
297+
public:
298+
INLINE bool operator()(const T a, const T b) const { return a >= b; }
299+
};
300+
301+
template <class T>
302+
class cmp_gt {
303+
public:
304+
INLINE bool operator()(const T a, const T b) const { return a > b; }
305+
};
306+
307+
template <class T>
308+
class and_op {
309+
public:
310+
INLINE bool operator()(const T a, const T b) const { return a && b; }
311+
};
312+
313+
template <class T>
314+
class or_op {
315+
public:
316+
INLINE bool operator()(const T a, const T b) const { return a || b; }
317+
};
318+
319+
template <class T>
320+
class min {
321+
public:
322+
INLINE T operator()(const T a, const T b) const { return a > b ? b : a; }
323+
};
324+
325+
template <class T>
326+
class max {
327+
public:
328+
INLINE T operator()(const T a, const T b) const { return a < b ? b : a; }
329+
};
330+
331+
} // namespace binary
332+
} // namespace hppl
333+
334+
#endif // HL_TENSOR_OPS_H_

paddle/gserver/activations/ActivationFunction.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ void forward(Argument& act) {
289289
useGpu(act.deviceId));
290290

291291
act.in->copyFrom(*act.value);
292-
act.value->abs(*act.value);
292+
act.value->abs2(*act.value);
293293
}
294294

295295
void backward(Argument& act) { act.grad->absDerivative(*act.in); }
@@ -311,7 +311,7 @@ void forward(Argument& act) {
311311
useGpu(act.deviceId));
312312

313313
act.in->copyFrom(*act.value);
314-
act.value->square(*act.value);
314+
act.value->square2(*act.value);
315315
}
316316

317317
void backward(Argument& act) { act.grad->squareDerivative(*act.in); }
@@ -324,7 +324,7 @@ END_DEFINE_ACTIVATION(square)
324324
* \f]
325325
*/
326326
BEGIN_DEFINE_ACTIVATION(exponential)
327-
void forward(Argument& act) { act.value->exp(*act.value); }
327+
void forward(Argument& act) { act.value->exp2(*act.value); }
328328

329329
void backward(Argument& act) { act.grad->expDerivative(*act.value); }
330330
END_DEFINE_ACTIVATION(exponential)
@@ -345,7 +345,7 @@ void forward(Argument& act) {
345345
useGpu(act.deviceId));
346346

347347
act.in->copyFrom(*act.value);
348-
act.value->log(*act.value);
348+
act.value->log2(*act.value);
349349
}
350350

351351
void backward(Argument& act) { act.grad->dotDiv(*act.grad, *act.in); }

0 commit comments

Comments
 (0)