@@ -375,15 +375,6 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
375375 pd .testing .assert_frame_equal (bf_df .to_pandas (), pd_df , check_dtype = False )
376376
377377
378- def test_where_series_cond (scalars_df_index , scalars_pandas_df_index ):
379- # Condition is dataframe, other is None (as default).
380- cond_bf = scalars_df_index ["int64_col" ] > 0
381- cond_pd = scalars_pandas_df_index ["int64_col" ] > 0
382- bf_result = scalars_df_index .where (cond_bf ).to_pandas ()
383- pd_result = scalars_pandas_df_index .where (cond_pd )
384- pandas .testing .assert_frame_equal (bf_result , pd_result )
385-
386-
387378def test_mask_series_cond (scalars_df_index , scalars_pandas_df_index ):
388379 cond_bf = scalars_df_index ["int64_col" ] > 0
389380 cond_pd = scalars_pandas_df_index ["int64_col" ] > 0
@@ -395,8 +386,102 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
395386 pandas .testing .assert_frame_equal (bf_result , pd_result )
396387
397388
398- def test_where_series_multi_index (scalars_df_index , scalars_pandas_df_index ):
399- # Test when a dataframe has multi-index or multi-columns.
389+ def test_where_multi_index (scalars_df_index , scalars_pandas_df_index ):
390+ columns = ["int64_col" , "float64_col" ]
391+
392+ # Prepre the multi-index.
393+ index = pd .MultiIndex .from_tuples (
394+ [
395+ (0 , "a" ),
396+ (1 , "b" ),
397+ (2 , "c" ),
398+ (0 , "d" ),
399+ (1 , "e" ),
400+ (2 , "f" ),
401+ (0 , "g" ),
402+ (1 , "h" ),
403+ (2 , "i" ),
404+ ],
405+ names = ["A" , "B" ],
406+ )
407+
408+ # Create multi-index dataframe.
409+ dataframe_bf = bpd .DataFrame (
410+ scalars_df_index [columns ].values ,
411+ index = index ,
412+ columns = scalars_df_index [columns ].columns ,
413+ )
414+ dataframe_pd = pd .DataFrame (
415+ scalars_pandas_df_index [columns ].values ,
416+ index = index ,
417+ columns = scalars_pandas_df_index [columns ].columns ,
418+ )
419+ dataframe_bf .columns .name = "test_name"
420+ dataframe_pd .columns .name = "test_name"
421+
422+ # Test1: when condition is series and other is None.
423+ series_cond_bf = dataframe_bf ["int64_col" ] > 0
424+ series_cond_pd = dataframe_pd ["int64_col" ] > 0
425+
426+ bf_result = dataframe_bf .where (series_cond_bf ).to_pandas ()
427+ pd_result = dataframe_pd .where (series_cond_pd )
428+ pandas .testing .assert_frame_equal (
429+ bf_result ,
430+ pd_result ,
431+ check_index_type = False ,
432+ check_dtype = False ,
433+ )
434+ # Assert the index is still MultiIndex after the operation.
435+ assert isinstance (bf_result .index , pd .MultiIndex ), "Expected a MultiIndex"
436+ assert isinstance (pd_result .index , pd .MultiIndex ), "Expected a MultiIndex"
437+
438+ # Test2: when condition is series and other is dataframe.
439+ series_cond_bf = dataframe_bf ["int64_col" ] > 1000.0
440+ series_cond_pd = dataframe_pd ["int64_col" ] > 1000.0
441+ dataframe_other_bf = dataframe_bf * 100.0
442+ dataframe_other_pd = dataframe_pd * 100.0
443+
444+ bf_result = dataframe_bf .where (series_cond_bf , dataframe_other_bf ).to_pandas ()
445+ pd_result = dataframe_pd .where (series_cond_pd , dataframe_other_pd )
446+ pandas .testing .assert_frame_equal (
447+ bf_result ,
448+ pd_result ,
449+ check_index_type = False ,
450+ check_dtype = False ,
451+ )
452+
453+ # Test3: when condition is dataframe and other is a constant.
454+ dataframe_cond_bf = dataframe_bf > 0
455+ dataframe_cond_pd = dataframe_pd > 0
456+ other = 0
457+
458+ bf_result = dataframe_bf .where (dataframe_cond_bf , other ).to_pandas ()
459+ pd_result = dataframe_pd .where (dataframe_cond_pd , other )
460+ pandas .testing .assert_frame_equal (
461+ bf_result ,
462+ pd_result ,
463+ check_index_type = False ,
464+ check_dtype = False ,
465+ )
466+
467+ # Test4: when condition is dataframe and other is dataframe.
468+ dataframe_cond_bf = dataframe_bf < 1000.0
469+ dataframe_cond_pd = dataframe_pd < 1000.0
470+ dataframe_other_bf = dataframe_bf * - 1.0
471+ dataframe_other_pd = dataframe_pd * - 1.0
472+
473+ bf_result = dataframe_bf .where (dataframe_cond_bf , dataframe_other_bf ).to_pandas ()
474+ pd_result = dataframe_pd .where (dataframe_cond_pd , dataframe_other_pd )
475+ pandas .testing .assert_frame_equal (
476+ bf_result ,
477+ pd_result ,
478+ check_index_type = False ,
479+ check_dtype = False ,
480+ )
481+
482+
483+ def test_where_series_multi_column (scalars_df_index , scalars_pandas_df_index ):
484+ # Test when a dataframe has multi-columns.
400485 columns = ["int64_col" , "float64_col" ]
401486 dataframe_bf = scalars_df_index [columns ]
402487
@@ -409,10 +494,19 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
409494 dataframe_bf .where (cond_bf ).to_pandas ()
410495 assert (
411496 str (context .value )
412- == "The dataframe.where() method does not support multi-index and/or multi- column."
497+ == "The dataframe.where() method does not support multi-column."
413498 )
414499
415500
501+ def test_where_series_cond (scalars_df_index , scalars_pandas_df_index ):
502+ # Condition is dataframe, other is None (as default).
503+ cond_bf = scalars_df_index ["int64_col" ] > 0
504+ cond_pd = scalars_pandas_df_index ["int64_col" ] > 0
505+ bf_result = scalars_df_index .where (cond_bf ).to_pandas ()
506+ pd_result = scalars_pandas_df_index .where (cond_pd )
507+ pandas .testing .assert_frame_equal (bf_result , pd_result )
508+
509+
416510def test_where_series_cond_const_other (scalars_df_index , scalars_pandas_df_index ):
417511 # Condition is a series, other is a constant.
418512 columns = ["int64_col" , "float64_col" ]
0 commit comments