Skip to content

Commit c5c8051

Browse files
committed
add BufferArg
1 parent 0c4be7e commit c5c8051

File tree

4 files changed

+436
-0
lines changed

4 files changed

+436
-0
lines changed

paddle/function/BufferArg.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <glog/logging.h>
16+
17+
#include "BufferArg.h"
18+
19+
namespace paddle {
20+
21+
/**
 * View this argument as a SequenceArg.
 *
 * The conversion is a checked downcast on a reference, so if the dynamic
 * type of *this is not SequenceArg the cast throws std::bad_cast.
 * The CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA) assertion stays disabled.
 */
const SequenceArg& BufferArg::sequence() const {
  const BufferArg& self = *this;
  return dynamic_cast<const SequenceArg&>(self);
}
25+
26+
/**
 * View this argument as a SparseMatrixArg.
 *
 * The conversion is a checked downcast on a reference, so if the dynamic
 * type of *this is not SparseMatrixArg the cast throws std::bad_cast.
 * The CHECK_EQ(bufferType_, TENSOR_SPARSE) assertion stays disabled.
 */
const SparseMatrixArg& BufferArg::sparse() const {
  const BufferArg& self = *this;
  return dynamic_cast<const SparseMatrixArg&>(self);
}
30+
31+
/**
 * Append a dense matrix argument described by an explicit shape.
 *
 * The matrix and shape are forwarded to the
 * BufferArg(const Matrix&, const TensorShape&) constructor.
 */
void BufferArgs::addArg(const Matrix& arg, const TensorShape& shape) {
  args_.emplace_back(std::make_shared<BufferArg>(arg, shape));
}
34+
35+
/**
 * Append a CPU sparse matrix argument.
 *
 * The matrix is wrapped in a SparseMatrixArg and stored through the
 * common BufferArgPtr element type.
 */
void BufferArgs::addArg(const CpuSparseMatrix& arg) {
  args_.emplace_back(std::make_shared<SparseMatrixArg>(arg));
}
38+
39+
/**
 * Append a GPU sparse matrix argument.
 *
 * The matrix is wrapped in a SparseMatrixArg and stored through the
 * common BufferArgPtr element type.
 */
void BufferArgs::addArg(const GpuSparseMatrix& arg) {
  args_.emplace_back(std::make_shared<SparseMatrixArg>(arg));
}
42+
43+
} // namespace paddle

paddle/function/BufferArg.h

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <glog/logging.h>
18+
19+
#include "TensorShape.h"
20+
#include "TensorType.h"
21+
#include "paddle/math/CpuSparseMatrix.h"
22+
#include "paddle/math/Matrix.h"
23+
#include "paddle/math/SparseMatrix.h"
24+
25+
namespace paddle {
26+
27+
/**
 * Tag for the kind of argument a BufferArg represents; it is the type
 * returned by BufferArg::bufferType().
 */
enum BufferType {
  TENSOR_NORMAL        = 0,  // plain buffer
  TENSOR_SEQUENCE_ID   = 1,  // presumably matches SequenceIdArg -- confirm
  TENSOR_SEQUENCE_DATA = 2,  // presumably matches SequenceArg -- confirm
  TENSOR_SPARSE        = 3   // presumably matches SparseMatrixArg -- confirm
};
33+
34+
/** Whether a sparse matrix stores explicit values for its entries. */
enum SparseDataType {
  // No value pointer is needed: every stored entry is implicitly 1.
  SPARSE_NO_VALUE = 0,
  // NOTE(review): presumably each stored entry carries a floating-point
  // value -- confirm.
  SPARSE_FLOAT_VALUE = 1
};
38+
39+
/** Storage layout tag of a sparse matrix. */
enum SparseDataFormat {
  SPARSE_CSR_FORMAT = 0,  // CSR layout
  SPARSE_CSC_FORMAT = 1   // CSC layout
};
40+
41+
/**
 * BufferArg is used as the argument type for Function.
 *
 * The argument classes are forward-declared here so that BufferArgs,
 * defined next, can refer to them before their full definitions.
 */
class BufferArg;
class SequenceArg;
class SparseMatrixArg;

/// Shared-ownership handle to a BufferArg (or a class derived from it).
using BufferArgPtr = std::shared_ptr<BufferArg>;
48+
49+
class BufferArgs {
50+
public:
51+
BufferArgs() {}
52+
size_t size() const { return args_.size(); }
53+
54+
// add argument into BufferArgss
55+
template <typename Tensor>
56+
void addArg(const Tensor& arg) {
57+
args_.push_back(std::make_shared<BufferArg>(arg));
58+
}
59+
60+
void addArg(const Matrix& arg, const TensorShape& shape);
61+
62+
void addArg(const CpuSparseMatrix& arg);
63+
void addArg(const GpuSparseMatrix& arg);
64+
65+
// get argument
66+
const BufferArg& operator[](size_t num) const {
67+
CHECK_LT(num, args_.size());
68+
return *args_[num];
69+
}
70+
71+
private:
72+
std::vector<BufferArgPtr> args_;
73+
};
74+
75+
// an array of arbitrary dimensions
76+
class BufferArg {
77+
public:
78+
BufferArg(void* buf, ValueType valueType, const TensorShape& shape)
79+
: buf_(buf), valueType_(valueType), shape_(shape) {}
80+
81+
BufferArg(void* buf, ValueType valueType)
82+
: buf_(buf), valueType_(valueType) {}
83+
84+
BufferArg(const Matrix& matrix)
85+
: buf_((void*)matrix.getData()),
86+
valueType_(DataType<real>::value),
87+
shape_(2) {
88+
shape_.setDim(0, matrix.getHeight());
89+
shape_.setDim(1, matrix.getWidth());
90+
}
91+
92+
BufferArg(const Matrix& matrix, const TensorShape& shape)
93+
: buf_((void*)matrix.getData()),
94+
valueType_(DataType<real>::value),
95+
shape_(shape) {
96+
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
97+
}
98+
99+
BufferArg(const Vector& vector)
100+
: buf_((void*)vector.getData()),
101+
valueType_(DataType<real>::value),
102+
shape_(1) {
103+
shape_.setDim(0, vector.getSize());
104+
}
105+
106+
BufferArg(const IVector& vector)
107+
: buf_((void*)vector.getData()), valueType_(VALUE_TYPE_INT32), shape_(1) {
108+
shape_.setDim(0, vector.getSize());
109+
}
110+
111+
template <DeviceType DType>
112+
typename Tensor<real, DType>::Matrix matrix() const {
113+
CHECK(buf_);
114+
CHECK(valueType_ == DataType<real>::value);
115+
// CHECK(deviceType_ == DType);
116+
CHECK_EQ(2, shape_.ndims());
117+
return typename Tensor<real, DType>::Matrix(
118+
reinterpret_cast<real*>(buf_), shape_[0], shape_[1]);
119+
}
120+
121+
template <typename VType, DeviceType DType>
122+
typename Tensor<VType, DType>::Vector vector() const {
123+
CHECK(buf_);
124+
CHECK(valueType_ == DataType<VType>::value);
125+
// CHECK(deviceType_ == DType);
126+
CHECK_EQ(1, shape_.ndims());
127+
return typename Tensor<VType, DType>::Vector(
128+
shape_[0], reinterpret_cast<VType*>(buf_));
129+
}
130+
131+
virtual ~BufferArg() {}
132+
133+
template <typename T>
134+
T* data() const {
135+
return reinterpret_cast<T*>(buf_);
136+
}
137+
138+
void* data() const { return buf_; }
139+
ValueType valueType() const { return valueType_; }
140+
BufferType bufferType() const { return bufferType_; }
141+
const TensorShape& shape() const { return shape_; }
142+
143+
const SequenceArg& sequence() const;
144+
const SparseMatrixArg& sparse() const;
145+
146+
protected:
147+
void* buf_;
148+
ValueType valueType_;
149+
TensorShape shape_;
150+
BufferType bufferType_;
151+
// leading dimensions. The size is dims_.size()
152+
// Dims lds_;
153+
};
154+
155+
// sequence start positions in a mini-batch of sequences
156+
// shape_.ndims() == 1
157+
// valueType_ = int32
158+
// if a < b than value_.buf_[a] < value_.buf_[b]
159+
class SequenceIdArg : public BufferArg {
160+
public:
161+
SequenceIdArg(void* buf, const TensorShape& shape)
162+
: BufferArg(buf, VALUE_TYPE_INT32, shape) {
163+
CHECK_EQ(shape_.ndims(), 1);
164+
numSeqs_ = shape_[0] - 1;
165+
}
166+
167+
SequenceIdArg(const IVector& vector) : BufferArg(vector) {
168+
numSeqs_ = shape_[0] - 1;
169+
}
170+
171+
~SequenceIdArg() {}
172+
173+
size_t numSeqs() const { return numSeqs_; }
174+
175+
private:
176+
size_t numSeqs_;
177+
};
178+
179+
// sequence data
180+
class SequenceArg : public BufferArg {
181+
public:
182+
SequenceArg(void* buf,
183+
ValueType valueType,
184+
const TensorShape& shape,
185+
const SequenceIdArg& startPositions)
186+
: BufferArg(buf, valueType, shape), startPositions_(startPositions) {}
187+
188+
SequenceArg(const Matrix& matrix, const IVector& vector)
189+
: BufferArg(matrix), startPositions_(vector) {}
190+
191+
~SequenceArg() {}
192+
193+
void* getIdBuf() const { return startPositions_.data(); }
194+
size_t numSeqs() const { return startPositions_.numSeqs(); }
195+
196+
private:
197+
SequenceIdArg startPositions_;
198+
};
199+
200+
// sparse matrix
201+
// valueType_ == float or double
202+
// shape_.ndims() == 2
203+
class SparseMatrixArg : public BufferArg {
204+
public:
205+
SparseMatrixArg(void* buf,
206+
ValueType valueType,
207+
const TensorShape& shape,
208+
const BufferArg& row,
209+
const BufferArg& col,
210+
size_t nnz,
211+
SparseDataFormat format,
212+
SparseDataType type)
213+
: BufferArg(buf, valueType, shape),
214+
row_(row),
215+
col_(col),
216+
nnz_(nnz),
217+
format_(format),
218+
type_(type) {
219+
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
220+
CHECK_EQ(shape_.ndims(), 2);
221+
CHECK_EQ(row_.shape().ndims(), 1);
222+
CHECK_EQ(col_.shape().ndims(), 1);
223+
if (format == SPARSE_CSR_FORMAT) {
224+
CHECK_EQ(nnz, col.shape()[0]);
225+
} else if (format == SPARSE_CSC_FORMAT) {
226+
CHECK_EQ(nnz, row.shape()[0]);
227+
}
228+
}
229+
230+
SparseMatrixArg(const CpuSparseMatrix& sparse)
231+
: BufferArg(sparse),
232+
row_((void*)sparse.getRows(), VALUE_TYPE_INT32),
233+
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {}
234+
235+
SparseMatrixArg(const GpuSparseMatrix& sparse)
236+
: BufferArg(sparse),
237+
row_((void*)sparse.getRows(), VALUE_TYPE_INT32),
238+
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {}
239+
240+
~SparseMatrixArg() {}
241+
242+
void* getRowBuf() const { return row_.data(); }
243+
244+
void* getColBuf() const { return col_.data(); }
245+
246+
size_t nnz() const { return nnz_; }
247+
248+
SparseDataFormat dataFormat() const { return format_; }
249+
250+
SparseDataType dataType() const { return type_; }
251+
252+
private:
253+
BufferArg row_;
254+
BufferArg col_;
255+
size_t nnz_;
256+
SparseDataFormat format_;
257+
SparseDataType type_;
258+
};
259+
260+
} // namespace paddle

0 commit comments

Comments
 (0)