@@ -38,16 +38,40 @@ enum SparseDataType {
3838
3939enum SparseDataFormat { SPARSE_CSR_FORMAT = 0 , SPARSE_CSC_FORMAT = 1 };
4040
41- /* *
42- * BufferArg used as the argument type for Function.
43- */
4441class BufferArg ;
4542class SequenceArg ;
4643class SparseMatrixArg ;
4744typedef std::shared_ptr<BufferArg> BufferArgPtr;
4845
49- // an array of arbitrary dimensions
46+ /* *
47+ * \brief BufferArg used as the argument type of Function.
48+ *
49+ * The arguments of the Paddle Function have four Buffer types.
50+ * 1. BufferArg for a dense Buffer of any dimension.
51+ * 2. SequenceIdArg for a Buffer of sequence start positions.
52+ * 3. SequenceArg for a Buffer of sequence data.
53+ * 4. SparseMatrixArg for a Buffer of sparse matrix.
54+ *
55+ * There is an ArgType property for the BufferArg used as Function Output.
56+ * Whether the result of the Function calculation is assigned to the
57+ * output Buffer or added to the output Buffer is determined by the
58+ * argType_ property of the output BufferArg.
59+ */
5060class BufferArg {
61+ public:
62+ // ArgType is only used by output BufferArg.
63+ // For input argument, argType_ is ignored.
64+ // For output argument, need to set the argType_ of the BufferArg.
65+ enum ArgType {
66+ UNSPECIFIED = 0 ,
67+ ASSIGN_TO = 1 ,
68+ ADD_TO = 2 ,
69+ };
70+
71+ void setArgType (ArgType argType) { argType_ = argType; }
72+
73+ ArgType getArgType () const { return argType_; }
74+
5175public:
5276 BufferArg (void * buf, ValueType valueType, const TensorShape& shape)
5377 : buf_(buf), valueType_(valueType), shape_(shape) {}
@@ -56,29 +80,33 @@ class BufferArg {
5680 : buf_(buf), valueType_(valueType) {}
5781
5882 BufferArg (const Matrix& matrix)
59- : buf_(reinterpret_cast <void *>(matrix.getData())),
83+ : buf_(
84+ const_cast <void *>(reinterpret_cast <const void *>(matrix.getData()))),
6085 valueType_ (DataType<real>::value),
6186 shape_(2 ) {
6287 shape_.setDim (0 , matrix.getHeight ());
6388 shape_.setDim (1 , matrix.getWidth ());
6489 }
6590
6691 BufferArg (const Matrix& matrix, const TensorShape& shape)
67- : buf_(reinterpret_cast <void *>(matrix.getData())),
92+ : buf_(
93+ const_cast <void *>(reinterpret_cast <const void *>(matrix.getData()))),
6894 valueType_(DataType<real>::value),
6995 shape_(shape) {
7096 CHECK_EQ (matrix.getElementCnt (), shape.getElements ());
7197 }
7298
7399 BufferArg (const Vector& vector)
74- : buf_(reinterpret_cast <void *>(vector.getData())),
100+ : buf_(
101+ const_cast <void *>(reinterpret_cast <const void *>(vector.getData()))),
75102 valueType_(DataType<real>::value),
76103 shape_(1 ) {
77104 shape_.setDim (0 , vector.getSize ());
78105 }
79106
80107 BufferArg (const IVector& vector)
81- : buf_(reinterpret_cast <void *>(vector.getData())),
108+ : buf_(
109+ const_cast <void *>(reinterpret_cast <const void *>(vector.getData()))),
82110 valueType_(VALUE_TYPE_INT32),
83111 shape_(1 ) {
84112 shape_.setDim (0 , vector.getSize ());
@@ -124,6 +152,7 @@ class BufferArg {
124152 ValueType valueType_;
125153 TensorShape shape_;
126154 BufferType bufferType_;
155+ ArgType argType_ = UNSPECIFIED;
127156 // leading dimensions. The size is dims_.size()
128157 // Dims lds_;
129158};
0 commit comments