Skip to content

Commit a980b83

Browse files
committed
Fix RNN unittest bugs.
* The DataProvider should be INCREF every time.
1 parent fefb3c1 commit a980b83

File tree

2 files changed

+6
-18
lines changed

2 files changed

+6
-18
lines changed

paddle/gserver/dataproviders/PyDataProvider2.cpp

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -252,19 +252,9 @@ class PyDataProvider2 : public DataProvider {
252252
// only for instance will make python reference-count error.
253253
//
254254
// So here, we increase reference count manually.
255-
if (gModuleClsPtrs_.find((uintptr_t)module.get()) !=
256-
gModuleClsPtrs_.end()) {
257-
// Multi instance use same module
258-
Py_XINCREF(module.get());
259-
Py_XINCREF(moduleDict.get());
260-
} else {
261-
gModuleClsPtrs_.insert((uintptr_t)module.get());
262-
}
263-
if (gModuleClsPtrs_.find((uintptr_t)cls.get()) != gModuleClsPtrs_.end()) {
264-
Py_XINCREF(cls.get());
265-
} else {
266-
gModuleClsPtrs_.insert((uintptr_t)cls.get());
267-
}
255+
Py_XINCREF(module.get());
256+
Py_XINCREF(moduleDict.get());
257+
Py_XINCREF(cls.get());
268258

269259
PyObjectPtr fileListInPy = loadPyFileLists(fileListName);
270260
PyDict_SetItemString(kwargs.get(), "file_list", fileListInPy.get());
@@ -471,7 +461,6 @@ class PyDataProvider2 : public DataProvider {
471461
std::vector<std::string> fileLists_;
472462
std::vector<SlotHeader> headers_;
473463
static PyObjectPtr zeroTuple_;
474-
static std::unordered_set<uintptr_t> gModuleClsPtrs_;
475464

476465
class PositionRandom {
477466
public:
@@ -671,7 +660,6 @@ class PyDataProvider2 : public DataProvider {
671660
}
672661
};
673662

674-
std::unordered_set<uintptr_t> PyDataProvider2::gModuleClsPtrs_;
675663
PyObjectPtr PyDataProvider2::zeroTuple_(PyTuple_New(0));
676664

677665
REGISTER_DATA_PROVIDER_EX(py2, PyDataProvider2);

paddle/gserver/tests/test_RecurrentGradientMachine.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ TEST(RecurrentGradientMachine, HasSubSequence) {
127127
}
128128
}
129129

130-
TEST(RecurrentGradientMachine, DISABLED_rnn) {
130+
TEST(RecurrentGradientMachine, rnn) {
131131
for (bool useGpu : {false, true}) {
132132
test("gserver/tests/sequence_rnn.conf",
133133
"gserver/tests/sequence_nest_rnn.conf",
@@ -136,7 +136,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn) {
136136
}
137137
}
138138

139-
TEST(RecurrentGradientMachine, DISABLED_rnn_multi_input) {
139+
TEST(RecurrentGradientMachine, rnn_multi_input) {
140140
for (bool useGpu : {false, true}) {
141141
test("gserver/tests/sequence_rnn_multi_input.conf",
142142
"gserver/tests/sequence_nest_rnn_multi_input.conf",
@@ -145,7 +145,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn_multi_input) {
145145
}
146146
}
147147

148-
TEST(RecurrentGradientMachine, DISABLED_rnn_multi_unequalength_input) {
148+
TEST(RecurrentGradientMachine, rnn_multi_unequalength_input) {
149149
for (bool useGpu : {false, true}) {
150150
test("gserver/tests/sequence_rnn_multi_unequalength_inputs.py",
151151
"gserver/tests/sequence_nest_rnn_multi_unequalength_inputs.py",

0 commit comments

Comments
 (0)