@@ -57,58 +57,67 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
5757 * output Buffer or added to the output Buffer is determined by the
5858 * argType_ property of the output BufferArg.
5959 */
60+
61+ // ArgType is only used by output BufferArg.
62+ // For input argument, argType_ is ignored.
63+ // For output argument, need to set the argType_ of the BufferArg.
64+ enum ArgType {
65+ UNSPECIFIED = 0 ,
66+ ASSIGN_TO = 1 ,
67+ ADD_TO = 2 ,
68+ };
6069class BufferArg {
6170public:
62- // ArgType is only used by output BufferArg.
63- // For input argument, argType_ is ignored.
64- // For output argument, need to set the argType_ of the BufferArg.
65- enum ArgType {
66- UNSPECIFIED = 0 ,
67- ASSIGN_TO = 1 ,
68- ADD_TO = 2 ,
69- };
70-
7171 void setArgType (ArgType argType) { argType_ = argType; }
7272
7373 ArgType getArgType () const { return argType_; }
7474
7575public:
76- BufferArg (void * buf, ValueType valueType, const TensorShape& shape)
77- : buf_(buf), valueType_(valueType), shape_(shape) {}
76+ BufferArg (void * buf,
77+ ValueType valueType,
78+ const TensorShape& shape,
79+ ArgType argType = UNSPECIFIED)
80+ : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
7881
7982 BufferArg (void * buf, ValueType valueType)
8083 : buf_(buf), valueType_(valueType) {}
8184
82- BufferArg (const Matrix& matrix)
85+ BufferArg (const Matrix& matrix, ArgType argType = UNSPECIFIED )
8386 : buf_(
8487 const_cast <void *>(reinterpret_cast <const void *>(matrix.getData()))),
8588 valueType_ (DataType<real>::value),
86- shape_(2 ) {
89+ shape_(2 ),
90+ argType_(argType) {
8791 shape_.setDim (0 , matrix.getHeight ());
8892 shape_.setDim (1 , matrix.getWidth ());
8993 }
9094
91- BufferArg (const Matrix& matrix, const TensorShape& shape)
95+ BufferArg (const Matrix& matrix,
96+ const TensorShape& shape,
97+ ArgType argType = UNSPECIFIED)
9298 : buf_(
9399 const_cast <void *>(reinterpret_cast <const void *>(matrix.getData()))),
94100 valueType_(DataType<real>::value),
95- shape_(shape) {
101+ shape_(shape),
102+ argType_(argType) {
96103 CHECK_EQ (matrix.getElementCnt (), shape.getElements ());
97104 }
98105
99- BufferArg (const Vector& vector)
106+ BufferArg (const Vector& vector, ArgType argType = UNSPECIFIED )
100107 : buf_(
101108 const_cast <void *>(reinterpret_cast <const void *>(vector.getData()))),
102109 valueType_(DataType<real>::value),
103- shape_(1 ) {
110+ shape_(1 ),
111+ argType_(argType) {
104112 shape_.setDim (0 , vector.getSize ());
105113 }
106114
107- BufferArg (const IVector& vector)
115+ BufferArg (const IVector& vector, ArgType argType = UNSPECIFIED )
108116 : buf_(
109117 const_cast <void *>(reinterpret_cast <const void *>(vector.getData()))),
110118 valueType_(VALUE_TYPE_INT32),
111- shape_(1 ) {
119+ shape_(1 ),
120+ argType_(argType) {
112121 shape_.setDim (0 , vector.getSize ());
113122 }
114123
@@ -163,8 +172,10 @@ class BufferArg {
163172// if a < b then value_.buf_[a] < value_.buf_[b]
164173class SequenceIdArg : public BufferArg {
165174public:
166- SequenceIdArg (void * buf, const TensorShape& shape)
167- : BufferArg(buf, VALUE_TYPE_INT32, shape) {
175+ SequenceIdArg (void * buf,
176+ const TensorShape& shape,
177+ ArgType argType = UNSPECIFIED)
178+ : BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
168179 CHECK_EQ (shape_.ndims (), 1 );
169180 numSeqs_ = shape_[0 ] - 1 ;
170181 }
@@ -187,11 +198,15 @@ class SequenceArg : public BufferArg {
187198 SequenceArg (void * buf,
188199 ValueType valueType,
189200 const TensorShape& shape,
190- const SequenceIdArg& startPositions)
191- : BufferArg(buf, valueType, shape), startPositions_(startPositions) {}
201+ const SequenceIdArg& startPositions,
202+ ArgType argType = UNSPECIFIED)
203+ : BufferArg(buf, valueType, shape, argType),
204+ startPositions_ (startPositions) {}
192205
193- SequenceArg (const Matrix& matrix, const IVector& vector)
194- : BufferArg(matrix), startPositions_(vector) {}
206+ SequenceArg (const Matrix& matrix,
207+ const IVector& vector,
208+ ArgType argType = UNSPECIFIED)
209+ : BufferArg(matrix, argType), startPositions_(vector) {}
195210
196211 ~SequenceArg () {}
197212
@@ -214,8 +229,9 @@ class SparseMatrixArg : public BufferArg {
214229 const BufferArg& col,
215230 size_t nnz,
216231 SparseDataFormat format,
217- SparseDataType type)
218- : BufferArg(buf, valueType, shape),
232+ SparseDataType type,
233+ ArgType argType = UNSPECIFIED)
234+ : BufferArg(buf, valueType, shape, argType),
219235 row_ (row),
220236 col_(col),
221237 nnz_(nnz),
@@ -232,13 +248,13 @@ class SparseMatrixArg : public BufferArg {
232248 }
233249 }
234250
235- SparseMatrixArg (const CpuSparseMatrix& sparse)
236- : BufferArg(sparse),
251+ SparseMatrixArg (const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED )
252+ : BufferArg(sparse, argType ),
237253 row_(reinterpret_cast <void *>(sparse.getRows()), VALUE_TYPE_INT32),
238254 col_(reinterpret_cast <void *>(sparse.getCols()), VALUE_TYPE_INT32) {}
239255
240- SparseMatrixArg (const GpuSparseMatrix& sparse)
241- : BufferArg(sparse),
256+ SparseMatrixArg (const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED )
257+ : BufferArg(sparse, argType ),
242258 row_(reinterpret_cast <void *>(sparse.getRows()), VALUE_TYPE_INT32),
243259 col_(reinterpret_cast <void *>(sparse.getCols()), VALUE_TYPE_INT32) {}
244260
0 commit comments