Skip to content

Commit d2da861

Browse files
committed
feat: support multi index for dataframe where
1 parent 6454aff commit d2da861

File tree

2 files changed

+108
-14
lines changed

2 files changed

+108
-14
lines changed

bigframes/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2741,9 +2741,9 @@ def where(self, cond, other=None):
27412741
if isinstance(other, bigframes.series.Series):
27422742
raise ValueError("Seires is not a supported replacement type!")
27432743

2744-
if self.columns.nlevels > 1 or self.index.nlevels > 1:
2744+
if self.columns.nlevels > 1:
27452745
raise NotImplementedError(
2746-
"The dataframe.where() method does not support multi-index and/or multi-column."
2746+
"The dataframe.where() method does not support multi-column."
27472747
)
27482748

27492749
aligned_block, (_, _) = self._block.join(cond._block, how="left")

tests/system/small/test_dataframe.py

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,6 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
375375
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False)
376376

377377

378-
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
379-
# Condition is dataframe, other is None (as default).
380-
cond_bf = scalars_df_index["int64_col"] > 0
381-
cond_pd = scalars_pandas_df_index["int64_col"] > 0
382-
bf_result = scalars_df_index.where(cond_bf).to_pandas()
383-
pd_result = scalars_pandas_df_index.where(cond_pd)
384-
pandas.testing.assert_frame_equal(bf_result, pd_result)
385-
386-
387378
def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
388379
cond_bf = scalars_df_index["int64_col"] > 0
389380
cond_pd = scalars_pandas_df_index["int64_col"] > 0
@@ -395,8 +386,102 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
395386
pandas.testing.assert_frame_equal(bf_result, pd_result)
396387

397388

398-
def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
399-
# Test when a dataframe has multi-index or multi-columns.
389+
def test_where_multi_index(scalars_df_index, scalars_pandas_df_index):
390+
columns = ["int64_col", "float64_col"]
391+
392+
# Prepare the multi-index.
393+
index = pd.MultiIndex.from_tuples(
394+
[
395+
(0, "a"),
396+
(1, "b"),
397+
(2, "c"),
398+
(0, "d"),
399+
(1, "e"),
400+
(2, "f"),
401+
(0, "g"),
402+
(1, "h"),
403+
(2, "i"),
404+
],
405+
names=["A", "B"],
406+
)
407+
408+
# Create multi-index dataframe.
409+
dataframe_bf = bpd.DataFrame(
410+
scalars_df_index[columns].values,
411+
index=index,
412+
columns=scalars_df_index[columns].columns,
413+
)
414+
dataframe_pd = pd.DataFrame(
415+
scalars_pandas_df_index[columns].values,
416+
index=index,
417+
columns=scalars_pandas_df_index[columns].columns,
418+
)
419+
dataframe_bf.columns.name = "test_name"
420+
dataframe_pd.columns.name = "test_name"
421+
422+
# Test1: when condition is series and other is None.
423+
series_cond_bf = dataframe_bf["int64_col"] > 0
424+
series_cond_pd = dataframe_pd["int64_col"] > 0
425+
426+
bf_result = dataframe_bf.where(series_cond_bf).to_pandas()
427+
pd_result = dataframe_pd.where(series_cond_pd)
428+
pandas.testing.assert_frame_equal(
429+
bf_result,
430+
pd_result,
431+
check_index_type=False,
432+
check_dtype=False,
433+
)
434+
# Assert the index is still MultiIndex after the operation.
435+
assert isinstance(bf_result.index, pd.MultiIndex), "Expected a MultiIndex"
436+
assert isinstance(pd_result.index, pd.MultiIndex), "Expected a MultiIndex"
437+
438+
# Test2: when condition is series and other is dataframe.
439+
series_cond_bf = dataframe_bf["int64_col"] > 1000.0
440+
series_cond_pd = dataframe_pd["int64_col"] > 1000.0
441+
dataframe_other_bf = dataframe_bf * 100.0
442+
dataframe_other_pd = dataframe_pd * 100.0
443+
444+
bf_result = dataframe_bf.where(series_cond_bf, dataframe_other_bf).to_pandas()
445+
pd_result = dataframe_pd.where(series_cond_pd, dataframe_other_pd)
446+
pandas.testing.assert_frame_equal(
447+
bf_result,
448+
pd_result,
449+
check_index_type=False,
450+
check_dtype=False,
451+
)
452+
453+
# Test3: when condition is dataframe and other is a constant.
454+
dataframe_cond_bf = dataframe_bf > 0
455+
dataframe_cond_pd = dataframe_pd > 0
456+
other = 0
457+
458+
bf_result = dataframe_bf.where(dataframe_cond_bf, other).to_pandas()
459+
pd_result = dataframe_pd.where(dataframe_cond_pd, other)
460+
pandas.testing.assert_frame_equal(
461+
bf_result,
462+
pd_result,
463+
check_index_type=False,
464+
check_dtype=False,
465+
)
466+
467+
# Test4: when condition is dataframe and other is dataframe.
468+
dataframe_cond_bf = dataframe_bf < 1000.0
469+
dataframe_cond_pd = dataframe_pd < 1000.0
470+
dataframe_other_bf = dataframe_bf * -1.0
471+
dataframe_other_pd = dataframe_pd * -1.0
472+
473+
bf_result = dataframe_bf.where(dataframe_cond_bf, dataframe_other_bf).to_pandas()
474+
pd_result = dataframe_pd.where(dataframe_cond_pd, dataframe_other_pd)
475+
pandas.testing.assert_frame_equal(
476+
bf_result,
477+
pd_result,
478+
check_index_type=False,
479+
check_dtype=False,
480+
)
481+
482+
483+
def test_where_series_multi_column(scalars_df_index, scalars_pandas_df_index):
484+
# Test when a dataframe has multi-columns.
400485
columns = ["int64_col", "float64_col"]
401486
dataframe_bf = scalars_df_index[columns]
402487

@@ -409,10 +494,19 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
409494
dataframe_bf.where(cond_bf).to_pandas()
410495
assert (
411496
str(context.value)
412-
== "The dataframe.where() method does not support multi-index and/or multi-column."
497+
== "The dataframe.where() method does not support multi-column."
413498
)
414499

415500

501+
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
502+
# Condition is a series, other is None (as default).
503+
cond_bf = scalars_df_index["int64_col"] > 0
504+
cond_pd = scalars_pandas_df_index["int64_col"] > 0
505+
bf_result = scalars_df_index.where(cond_bf).to_pandas()
506+
pd_result = scalars_pandas_df_index.where(cond_pd)
507+
pandas.testing.assert_frame_equal(bf_result, pd_result)
508+
509+
416510
def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index):
417511
# Condition is a series, other is a constant.
418512
columns = ["int64_col", "float64_col"]

0 commit comments

Comments
 (0)