From 7b80215fea5972cd12b549984b1ea11d76162a1a Mon Sep 17 00:00:00 2001 From: jialuo Date: Wed, 6 Aug 2025 23:50:43 +0000 Subject: [PATCH 1/2] feat: Allow callable as a conditional or replacement input in DataFrame.where() --- bigframes/dataframe.py | 6 ++++ tests/system/small/test_dataframe.py | 47 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 7de4bdbc91..858be3de45 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2763,6 +2763,12 @@ def where(self, cond, other=None): "The dataframe.where() method does not support multi-column." ) + # Execute it with the DataFrame when cond or/and other is callable. + if callable(cond): + cond = cond(self) + if callable(other): + other = other(self) + aligned_block, (_, _) = self._block.join(cond._block, how="left") # No left join is needed when 'other' is None or constant. if isinstance(other, bigframes.dataframe.DataFrame): diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index bc773d05b2..a3027917f9 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -514,6 +514,53 @@ def test_where_dataframe_cond_dataframe_other( pandas.testing.assert_frame_equal(bf_result, pd_result) +def test_where_callable_cond_constant_other(scalars_df_index, scalars_pandas_df_index): + # Condition is callable, other is a constant. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond = lambda x: x > 0 + other = 10 + + bf_result = dataframe_bf.where(cond, other).to_pandas() + pd_result = dataframe_pd.where(cond, other) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_dataframe_cond_callable_other(scalars_df_index, scalars_pandas_df_index): + # Condition is a dataframe, other is callable. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf > 0 + cond_pd = dataframe_pd > 0 + + def func(x): + return x * 2 + + bf_result = dataframe_bf.where(cond_bf, func).to_pandas() + pd_result = dataframe_pd.where(cond_pd, func) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_callable_cond_callable_other(scalars_df_index, scalars_pandas_df_index): + # Condition is callable, other is callable too. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + def func(x): + return x["int64_col"] > 0 + + other = lambda x: x * 2 + + bf_result = dataframe_bf.where(func, other).to_pandas() + pd_result = dataframe_pd.where(func, other) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + def test_drop_column(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_name = "int64_col" From ae42e8cf47a8fb673118792c79a5ede0332e40df Mon Sep 17 00:00:00 2001 From: jialuo Date: Thu, 7 Aug 2025 15:58:01 +0000 Subject: [PATCH 2/2] fix lint --- tests/system/small/test_dataframe.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index a3027917f9..bef071f6a7 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -520,11 +520,10 @@ def test_where_callable_cond_constant_other(scalars_df_index, scalars_pandas_df_ dataframe_bf = scalars_df_index[columns] dataframe_pd = scalars_pandas_df_index[columns] - cond = lambda x: x > 0 other = 10 - bf_result = dataframe_bf.where(cond, other).to_pandas() - pd_result = dataframe_pd.where(cond, other) + bf_result = dataframe_bf.where(lambda x: x > 0, other).to_pandas() + pd_result = dataframe_pd.where(lambda x: x > 0, other) pandas.testing.assert_frame_equal(bf_result, pd_result) @@ -554,10 +553,8 @@ def test_where_callable_cond_callable_other(scalars_df_index, scalars_pandas_df_ def func(x): return x["int64_col"] > 0 - other = lambda x: x * 2 - - bf_result = dataframe_bf.where(func, other).to_pandas() - pd_result = dataframe_pd.where(func, other) + bf_result = dataframe_bf.where(func, lambda x: x * 2).to_pandas() + pd_result = dataframe_pd.where(func, lambda x: x * 2) pandas.testing.assert_frame_equal(bf_result, pd_result)