@@ -22,6 +22,11 @@ limitations under the License. */
2222
2323namespace paddle {
2424
25+ /* *
26+ * Function Configuration.
27+ * The argument type of Function::init.
28+ * Follow-up will consider moving this data structure to Proto inside.
29+ */
2530class FuncConfig {
2631public:
2732 union value {
@@ -41,6 +46,43 @@ class FuncConfig {
4146 std::map<std::string, value> valueMap_;
4247};
4348
49+ /* *
50+ * Argument type for Function::calc().
51+ * A BufferArgs contains a set of BufferArg,
52+ * because Function can have multiple inputs, outputs and inouts.
53+ */
54+ class BufferArgs {
55+ public:
56+ BufferArgs () {}
57+ size_t size () const { return args_.size (); }
58+
59+ // add argument into BufferArgss
60+ template <typename Tensor>
61+ void addArg (const Tensor& arg) {
62+ args_.push_back (std::make_shared<BufferArg>(arg));
63+ }
64+
65+ void addArg (const Matrix& arg, const TensorShape& shape);
66+
67+ void addArg (const CpuSparseMatrix& arg);
68+ void addArg (const GpuSparseMatrix& arg);
69+
70+ // get argument
71+ const BufferArg& operator [](size_t num) const {
72+ CHECK_LT (num, args_.size ());
73+ return *args_[num];
74+ }
75+
76+ private:
77+ std::vector<BufferArgPtr> args_;
78+ };
79+
80+ /* *
81+ * Base class for Function.
82+ * The basic Function implementation requires override init and calc interfaces.
83+ * Need to pay attention to the inouts argument. For the input argument
84+ * that will be modified, it needs to be passed through inouts.
85+ */
4486class FunctionBase {
4587public:
4688 virtual ~FunctionBase () {}
0 commit comments