@@ -15,16 +15,16 @@ limitations under the License. */
1515#ifndef PADDLE_NO_PYTHON
1616#include < gtest/gtest.h>
1717#include < fstream>
18- #include " paddle/utils/Util.h"
19- #include " paddle/utils/PythonUtil.h"
2018#include " paddle/gserver/dataproviders/DataProvider.h"
19+ #include " paddle/utils/PythonUtil.h"
20+ #include " paddle/utils/Util.h"
2121
2222P_DEFINE_string (train_list, " unittest.list" , " file list for unittest" );
2323
2424namespace paddle {
2525namespace unittest {
2626namespace pydp2 {
27- extern void setOnPoolFilledHook (const std::function<void (size_t )>& func);
27+ extern void setOnPoolFilledHook (const std::function<void (size_t )> & func);
2828extern void clearOnPoolFilledHook ();
2929
3030} // namespace pydp2
@@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook();
3333
3434const paddle::real epsilon = 1e-5 ;
3535
36- static inline int64_t readDataBatch (paddle::DataBatch* batch,
37- const std::string& funcName,
36+ static inline int64_t readDataBatch (paddle::DataBatch * batch,
37+ const std::string & funcName,
3838 int64_t batchSize = 65535 ) {
3939 paddle::DataConfig config;
4040 config.set_type (" py2" );
@@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) {
143143 paddle::DataBatch batch;
144144 int64_t num = provider->getNextBatchInternal (100000 , &batch);
145145 ASSERT_EQ (num, 200 );
146- auto & mat = batch.getStreams ()[0 ].value ;
146+ auto & mat = batch.getStreams ()[0 ].value ;
147147 ASSERT_EQ ((size_t )mat->getWidth (), (size_t )20 );
148148 for (size_t i = 0 ; i < 200 ; ++i) {
149149 for (size_t j = 0 ; j < 20 ; ++j) {
@@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) {
170170 CHECK (csm != nullptr );
171171 for (int i = 0 ; i < 200 ; ++i) {
172172 CHECK_EQ (csm->getColNum (i), (size_t )10 );
173- int * cols = csm->getRowCols (i);
173+ int * cols = csm->getRowCols (i);
174174 for (int j = 0 ; j < 10 ; ++j) {
175175 CHECK_EQ (cols[j], (i + 1 ) * (j + 1 ));
176176 }
@@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
185185 CHECK (csm != nullptr );
186186 for (int i = 0 ; i < 200 ; ++i) {
187187 CHECK_EQ (csm->getColNum (i), (size_t )10 );
188- int * cols = csm->getRowCols (i);
189- real* dat = csm->getRowValues (i);
188+ int * cols = csm->getRowCols (i);
189+ real * dat = csm->getRowValues (i);
190190 for (int j = 0 ; j < 10 ; ++j) {
191191 EXPECT_EQ (cols[j], (i + 1 ) * (j + 1 ));
192192 EXPECT_EQ (dat[j], real (j) / real (i + 1 ));
@@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
197197TEST (PyDataProvider2, index_seq) {
198198 paddle::DataBatch batch;
199199 CHECK_EQ (readDataBatch (&batch, " test_index_seq" ), 200 );
200- auto & arg = batch.getStreams ()[0 ];
200+ auto & arg = batch.getStreams ()[0 ];
201201 CHECK_EQ ((int )arg.ids ->getSize (), (200 + 1 ) * 200 / 2 );
202202 size_t tmp = 0 ;
203203 for (size_t i = 0 ; i < 200 ; ++i) { // CHECK DATA CORRECT
@@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) {
219219TEST (PyDataProvider2, index_sub_seq) {
220220 paddle::DataBatch batch;
221221 ASSERT_EQ (readDataBatch (&batch, " test_index_sub_seq" ), 200 );
222- auto & arg = batch.getStreams ()[0 ];
222+ auto & arg = batch.getStreams ()[0 ];
223223 size_t tmp = 0 ;
224224 for (size_t i = 0 ; i < 200 ; ++i) {
225225 for (size_t j = 0 ; j < i + 1 ; ++j) {
@@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) {
268268 }
269269 });
270270 while (true ) {
271- size_t realBatchSize = provider->getNextBatchInternal (batchSize, &batch);
271+ int64_t realBatchSize = provider->getNextBatchInternal (batchSize, &batch);
272272 if (realBatchSize) {
273273 totalData -= realBatchSize;
274274 } else {
@@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
291291 provider->reset ();
292292 constexpr size_t batchSize = 100 ;
293293 while (true ) {
294- size_t realBatchSize = provider->getNextBatchInternal (batchSize, &batch);
294+ int64_t realBatchSize = provider->getNextBatchInternal (batchSize, &batch);
295295 if (realBatchSize) {
296296 CHECK_LE (realBatchSize, batchSize);
297297 } else {
@@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) {
317317 provider->reset ();
318318 constexpr size_t batchSize = 100 ;
319319 while (true ) {
320- size_t realBatchSize = provider->getNextBatchInternal (batchSize, &batch);
320+ int64_t realBatchSize = provider->getNextBatchInternal (batchSize, &batch);
321321 if (!realBatchSize) {
322322 break ;
323323 }
324- ASSERT_EQ (batch.getStreams ().size (), ( size_t ) 2 );
325- for (size_t i = 0 ; i < realBatchSize; ++i) {
324+ ASSERT_EQ (batch.getStreams ().size (), static_cast < size_t >( 2 ) );
325+ for (int64_t i = 0 ; i < realBatchSize; ++i) {
326326 ASSERT_EQ (batch.getStream (0 ).ids ->getData ()[i], 0 );
327327 ASSERT_EQ (batch.getStream (1 ).ids ->getData ()[i], 1 );
328328 }
@@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) {
341341 paddle::DataProvider::create (config, false ));
342342 provider->reset ();
343343 while (true ) {
344- size_t realBatchSize = provider->getNextBatchInternal (100 , &batch);
344+ int64_t realBatchSize = provider->getNextBatchInternal (100 , &batch);
345345 if (!realBatchSize) {
346346 break ;
347347 } else {
348- auto & ivec = batch.getStream (0 ).ids ;
348+ auto & ivec = batch.getStream (0 ).ids ;
349349 for (size_t i = 0 ; i < ivec->getSize (); ++i) {
350350 CHECK_LT (ivec->getData ()[i], 10 );
351351 }
@@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) {
370370 provider.reset ();
371371}
372372
373- int main (int argc, char ** argv) {
373+ TEST (PyDataProvider2, minPoolSizeWithCache) {
374+ paddle::DataConfig config;
375+ config.set_type (" py2" );
376+ config.set_files (FLAGS_train_list.c_str ());
377+ config.set_load_data_module (" test_PyDataProvider2" );
378+ config.set_load_data_object (" test_min_pool_size_with_cache" );
379+ config.set_async_load_data (true );
380+
381+ std::unique_ptr<paddle::DataProvider> provider (
382+ paddle::DataProvider::create (config, false ));
383+
384+ paddle::DataBatch batch;
385+
386+ for (int i = 0 ; i < 10 ; ++i) {
387+ provider->reset ();
388+ int64_t sum = 0 ;
389+ while (int64_t actualNum = provider->getNextBatch (100 , &batch)) {
390+ sum += actualNum;
391+ }
392+ ASSERT_EQ (1 << 20 , sum);
393+ }
394+ }
395+
396+ int main (int argc, char **argv) {
374397 testing::InitGoogleTest (&argc, argv);
375398 paddle::initMain (argc, argv);
376399 paddle::initPython (argc, argv);
0 commit comments