diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 1ef287842e..1884f0beff 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2741,9 +2741,9 @@ def where(self, cond, other=None): if isinstance(other, bigframes.series.Series): raise ValueError("Seires is not a supported replacement type!") - if self.columns.nlevels > 1 or self.index.nlevels > 1: + if self.columns.nlevels > 1: raise NotImplementedError( - "The dataframe.where() method does not support multi-index and/or multi-column." + "The dataframe.where() method does not support multi-column." ) aligned_block, (_, _) = self._block.join(cond._block, how="left") diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 5045e2268f..91a83dfd73 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -375,15 +375,6 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates): pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False) -def test_where_series_cond(scalars_df_index, scalars_pandas_df_index): - # Condition is dataframe, other is None (as default). - cond_bf = scalars_df_index["int64_col"] > 0 - cond_pd = scalars_pandas_df_index["int64_col"] > 0 - bf_result = scalars_df_index.where(cond_bf).to_pandas() - pd_result = scalars_pandas_df_index.where(cond_pd) - pandas.testing.assert_frame_equal(bf_result, pd_result) - - def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index): cond_bf = scalars_df_index["int64_col"] > 0 cond_pd = scalars_pandas_df_index["int64_col"] > 0 @@ -395,8 +386,8 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index): pandas.testing.assert_frame_equal(bf_result, pd_result) -def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index): - # Test when a dataframe has multi-index or multi-columns. +def test_where_multi_column(scalars_df_index, scalars_pandas_df_index): + # Test when a dataframe has multi-columns. 
columns = ["int64_col", "float64_col"] dataframe_bf = scalars_df_index[columns] @@ -409,10 +400,19 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index): dataframe_bf.where(cond_bf).to_pandas() assert ( str(context.value) - == "The dataframe.where() method does not support multi-index and/or multi-column." + == "The dataframe.where() method does not support multi-column." ) +def test_where_series_cond(scalars_df_index, scalars_pandas_df_index): + # Condition is a series, other is None (as default). + cond_bf = scalars_df_index["int64_col"] > 0 + cond_pd = scalars_pandas_df_index["int64_col"] > 0 + bf_result = scalars_df_index.where(cond_bf).to_pandas() + pd_result = scalars_pandas_df_index.where(cond_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index): # Condition is a series, other is a constant. columns = ["int64_col", "float64_col"] diff --git a/tests/system/small/test_multiindex.py b/tests/system/small/test_multiindex.py index b63468d311..e4852cc8fb 100644 --- a/tests/system/small/test_multiindex.py +++ b/tests/system/small/test_multiindex.py @@ -19,6 +19,22 @@ import bigframes.pandas as bpd from bigframes.testing.utils import assert_pandas_df_equal +# Sample MultiIndex for testing the DataFrame.where() method. +_MULTI_INDEX = pandas.MultiIndex.from_tuples( + [ + (0, "a"), + (1, "b"), + (2, "c"), + (0, "d"), + (1, "e"), + (2, "f"), + (0, "g"), + (1, "h"), + (2, "i"), + ], + names=["A", "B"], +) + def test_multi_index_from_arrays(): bf_idx = bpd.MultiIndex.from_arrays( @@ -541,6 +557,140 @@ def test_multi_index_dataframe_join_on(scalars_dfs, how): assert_pandas_df_equal(bf_result, pd_result, ignore_order=True) +def test_multi_index_dataframe_where_series_cond_none_other( + scalars_df_index, scalars_pandas_df_index +): + columns = ["int64_col", "float64_col"] + + # Create multi-index dataframe. 
+ dataframe_bf = bpd.DataFrame( + scalars_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_df_index[columns].columns, + ) + dataframe_pd = pandas.DataFrame( + scalars_pandas_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_pandas_df_index[columns].columns, + ) + dataframe_bf.columns.name = "test_name" + dataframe_pd.columns.name = "test_name" + + # When condition is series and other is None. + series_cond_bf = dataframe_bf["int64_col"] > 0 + series_cond_pd = dataframe_pd["int64_col"] > 0 + + bf_result = dataframe_bf.where(series_cond_bf).to_pandas() + pd_result = dataframe_pd.where(series_cond_pd) + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + # Assert the index is still MultiIndex after the operation. + assert isinstance(bf_result.index, pandas.MultiIndex), "Expected a MultiIndex" + assert isinstance(pd_result.index, pandas.MultiIndex), "Expected a MultiIndex" + + +def test_multi_index_dataframe_where_series_cond_dataframe_other( + scalars_df_index, scalars_pandas_df_index +): + columns = ["int64_col", "int64_too"] + + # Create multi-index dataframe. + dataframe_bf = bpd.DataFrame( + scalars_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_df_index[columns].columns, + ) + dataframe_pd = pandas.DataFrame( + scalars_pandas_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_pandas_df_index[columns].columns, + ) + + # When condition is series and other is dataframe. 
+ series_cond_bf = dataframe_bf["int64_col"] > 1000.0 + series_cond_pd = dataframe_pd["int64_col"] > 1000.0 + dataframe_other_bf = dataframe_bf * 100.0 + dataframe_other_pd = dataframe_pd * 100.0 + + bf_result = dataframe_bf.where(series_cond_bf, dataframe_other_bf).to_pandas() + pd_result = dataframe_pd.where(series_cond_pd, dataframe_other_pd) + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + +def test_multi_index_dataframe_where_dataframe_cond_constant_other( + scalars_df_index, scalars_pandas_df_index +): + columns = ["int64_col", "float64_col"] + + # Create multi-index dataframe. + dataframe_bf = bpd.DataFrame( + scalars_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_df_index[columns].columns, + ) + dataframe_pd = pandas.DataFrame( + scalars_pandas_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_pandas_df_index[columns].columns, + ) + + # When condition is dataframe and other is a constant. + dataframe_cond_bf = dataframe_bf > 0 + dataframe_cond_pd = dataframe_pd > 0 + other = 0 + + bf_result = dataframe_bf.where(dataframe_cond_bf, other).to_pandas() + pd_result = dataframe_pd.where(dataframe_cond_pd, other) + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + +def test_multi_index_dataframe_where_dataframe_cond_dataframe_other( + scalars_df_index, scalars_pandas_df_index +): + columns = ["int64_col", "int64_too", "float64_col"] + + # Create multi-index dataframe. + dataframe_bf = bpd.DataFrame( + scalars_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_df_index[columns].columns, + ) + dataframe_pd = pandas.DataFrame( + scalars_pandas_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_pandas_df_index[columns].columns, + ) + + # When condition is dataframe and other is dataframe. 
+ dataframe_cond_bf = dataframe_bf < 1000.0 + dataframe_cond_pd = dataframe_pd < 1000.0 + dataframe_other_bf = dataframe_bf * -1.0 + dataframe_other_pd = dataframe_pd * -1.0 + + bf_result = dataframe_bf.where(dataframe_cond_bf, dataframe_other_bf).to_pandas() + pd_result = dataframe_pd.where(dataframe_cond_pd, dataframe_other_pd) + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + @pytest.mark.parametrize( ("level",), [ diff --git a/tests/unit/test_dataframe_polars.py b/tests/unit/test_dataframe_polars.py index 467cf7ce3d..eae800d409 100644 --- a/tests/unit/test_dataframe_polars.py +++ b/tests/unit/test_dataframe_polars.py @@ -364,7 +364,7 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index): dataframe_bf.where(cond_bf).to_pandas() assert ( str(context.value) - == "The dataframe.where() method does not support multi-index and/or multi-column." + == "The dataframe.where() method does not support multi-column." )