Skip to content

Commit adc23f6

Browse files
authored
Merge pull request #752 from reyoung/feature/fix_data_loss_in_pydp2
Add unittest related #653
2 parents b6d036a + 1539335 commit adc23f6

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

paddle/gserver/tests/test_PyDataProvider2.cpp

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ limitations under the License. */
1515
#ifndef PADDLE_NO_PYTHON
1616
#include <gtest/gtest.h>
1717
#include <fstream>
18-
#include "paddle/utils/Util.h"
19-
#include "paddle/utils/PythonUtil.h"
2018
#include "paddle/gserver/dataproviders/DataProvider.h"
19+
#include "paddle/utils/PythonUtil.h"
20+
#include "paddle/utils/Util.h"
2121

2222
P_DEFINE_string(train_list, "unittest.list", "file list for unittest");
2323

2424
namespace paddle {
2525
namespace unittest {
2626
namespace pydp2 {
27-
extern void setOnPoolFilledHook(const std::function<void(size_t)>& func);
27+
extern void setOnPoolFilledHook(const std::function<void(size_t)> &func);
2828
extern void clearOnPoolFilledHook();
2929

3030
} // namespace pydp2
@@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook();
3333

3434
const paddle::real epsilon = 1e-5;
3535

36-
static inline int64_t readDataBatch(paddle::DataBatch* batch,
37-
const std::string& funcName,
36+
static inline int64_t readDataBatch(paddle::DataBatch *batch,
37+
const std::string &funcName,
3838
int64_t batchSize = 65535) {
3939
paddle::DataConfig config;
4040
config.set_type("py2");
@@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) {
143143
paddle::DataBatch batch;
144144
int64_t num = provider->getNextBatchInternal(100000, &batch);
145145
ASSERT_EQ(num, 200);
146-
auto& mat = batch.getStreams()[0].value;
146+
auto &mat = batch.getStreams()[0].value;
147147
ASSERT_EQ((size_t)mat->getWidth(), (size_t)20);
148148
for (size_t i = 0; i < 200; ++i) {
149149
for (size_t j = 0; j < 20; ++j) {
@@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) {
170170
CHECK(csm != nullptr);
171171
for (int i = 0; i < 200; ++i) {
172172
CHECK_EQ(csm->getColNum(i), (size_t)10);
173-
int* cols = csm->getRowCols(i);
173+
int *cols = csm->getRowCols(i);
174174
for (int j = 0; j < 10; ++j) {
175175
CHECK_EQ(cols[j], (i + 1) * (j + 1));
176176
}
@@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
185185
CHECK(csm != nullptr);
186186
for (int i = 0; i < 200; ++i) {
187187
CHECK_EQ(csm->getColNum(i), (size_t)10);
188-
int* cols = csm->getRowCols(i);
189-
real* dat = csm->getRowValues(i);
188+
int *cols = csm->getRowCols(i);
189+
real *dat = csm->getRowValues(i);
190190
for (int j = 0; j < 10; ++j) {
191191
EXPECT_EQ(cols[j], (i + 1) * (j + 1));
192192
EXPECT_EQ(dat[j], real(j) / real(i + 1));
@@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
197197
TEST(PyDataProvider2, index_seq) {
198198
paddle::DataBatch batch;
199199
CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200);
200-
auto& arg = batch.getStreams()[0];
200+
auto &arg = batch.getStreams()[0];
201201
CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2);
202202
size_t tmp = 0;
203203
for (size_t i = 0; i < 200; ++i) { // CHECK DATA CORRECT
@@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) {
219219
TEST(PyDataProvider2, index_sub_seq) {
220220
paddle::DataBatch batch;
221221
ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200);
222-
auto& arg = batch.getStreams()[0];
222+
auto &arg = batch.getStreams()[0];
223223
size_t tmp = 0;
224224
for (size_t i = 0; i < 200; ++i) {
225225
for (size_t j = 0; j < i + 1; ++j) {
@@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) {
268268
}
269269
});
270270
while (true) {
271-
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
271+
int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
272272
if (realBatchSize) {
273273
totalData -= realBatchSize;
274274
} else {
@@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
291291
provider->reset();
292292
constexpr size_t batchSize = 100;
293293
while (true) {
294-
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
294+
int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
295295
if (realBatchSize) {
296296
CHECK_LE(realBatchSize, batchSize);
297297
} else {
@@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) {
317317
provider->reset();
318318
constexpr size_t batchSize = 100;
319319
while (true) {
320-
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
320+
int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
321321
if (!realBatchSize) {
322322
break;
323323
}
324-
ASSERT_EQ(batch.getStreams().size(), (size_t)2);
325-
for (size_t i = 0; i < realBatchSize; ++i) {
324+
ASSERT_EQ(batch.getStreams().size(), static_cast<size_t>(2));
325+
for (int64_t i = 0; i < realBatchSize; ++i) {
326326
ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0);
327327
ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1);
328328
}
@@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) {
341341
paddle::DataProvider::create(config, false));
342342
provider->reset();
343343
while (true) {
344-
size_t realBatchSize = provider->getNextBatchInternal(100, &batch);
344+
int64_t realBatchSize = provider->getNextBatchInternal(100, &batch);
345345
if (!realBatchSize) {
346346
break;
347347
} else {
348-
auto& ivec = batch.getStream(0).ids;
348+
auto &ivec = batch.getStream(0).ids;
349349
for (size_t i = 0; i < ivec->getSize(); ++i) {
350350
CHECK_LT(ivec->getData()[i], 10);
351351
}
@@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) {
370370
provider.reset();
371371
}
372372

373-
int main(int argc, char** argv) {
373+
// Data-loss regression test for the py2 provider: loads the Python provider
// object "test_min_pool_size_with_cache" with asynchronous loading enabled and
// checks that each of ten consecutive passes delivers the full sample count.
TEST(PyDataProvider2, minPoolSizeWithCache) {
374+
paddle::DataConfig config;
375+
config.set_type("py2");
376+
config.set_files(FLAGS_train_list.c_str());
377+
config.set_load_data_module("test_PyDataProvider2");
378+
config.set_load_data_object("test_min_pool_size_with_cache");
379+
config.set_async_load_data(true);
380+
381+
std::unique_ptr<paddle::DataProvider> provider(
382+
paddle::DataProvider::create(config, false));
383+
384+
paddle::DataBatch batch;
385+
386+
// Repeat the pass ten times; reset() is called at the start of every pass.
for (int i = 0; i < 10; ++i) {
387+
provider->reset();
388+
int64_t sum = 0;
389+
// Pull batches (requested size 100) until getNextBatch returns 0, adding up
// the number of samples actually returned.
while (int64_t actualNum = provider->getNextBatch(100, &batch)) {
390+
sum += actualNum;
391+
}
392+
// 1 << 20 equals the 2**20 samples yielded by test_min_pool_size_with_cache
// in test_PyDataProvider2.py; a smaller sum means samples were dropped.
ASSERT_EQ(1 << 20, sum);
393+
}
394+
}
395+
396+
int main(int argc, char **argv) {
374397
testing::InitGoogleTest(&argc, argv);
375398
paddle::initMain(argc, argv);
376399
paddle::initPython(argc, argv);

paddle/gserver/tests/test_PyDataProvider2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,13 @@ def test_check(settings, filename):
111111
if i < 10:
112112
yield_good_value = True
113113
yield i
114+
115+
116+
# Provider used by the C++ test PyDataProvider2.minPoolSizeWithCache.
# It is declared with min_pool_size=1000 and CacheType.CACHE_PASS_IN_MEM and
# yields 2**20 samples per call; the C++ side asserts that exactly 1 << 20
# samples come back on every pass.
@provider(
117+
input_types=[index_slot(10)],
118+
min_pool_size=1000,
119+
cache=CacheType.CACHE_PASS_IN_MEM, )
120+
def test_min_pool_size_with_cache(settings, filename):
121+
import random
122+
# The sample values are random; only the total count is checked by the test.
for _ in xrange(2**20):
123+
# randint is inclusive on both ends, so each sample is an index in 0..9.
yield random.randint(0, 9)

0 commit comments

Comments
 (0)