Support callable `cond` and `other` arguments in DataFrame.where().

When `cond` and/or `other` is callable, where() now calls it with the
DataFrame itself and uses the returned value in its place before the
existing join logic runs. Three system tests are added that compare the
result against pandas for: callable cond with a constant other, DataFrame
cond with a callable other, and callable cond with a callable other.

diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index 7de4bdbc91..858be3de45 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -2763,6 +2763,12 @@ def where(self, cond, other=None):
                 "The dataframe.where() method does not support multi-column."
             )
 
+        # Execute it with the DataFrame when cond or/and other is callable.
+        if callable(cond):
+            cond = cond(self)
+        if callable(other):
+            other = other(self)
+
         aligned_block, (_, _) = self._block.join(cond._block, how="left")
         # No left join is needed when 'other' is None or constant.
         if isinstance(other, bigframes.dataframe.DataFrame):
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index bc773d05b2..bef071f6a7 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -514,6 +514,50 @@ def test_where_dataframe_cond_dataframe_other(
     pandas.testing.assert_frame_equal(bf_result, pd_result)
 
 
+def test_where_callable_cond_constant_other(scalars_df_index, scalars_pandas_df_index):
+    # Condition is callable, other is a constant.
+    columns = ["int64_col", "float64_col"]
+    dataframe_bf = scalars_df_index[columns]
+    dataframe_pd = scalars_pandas_df_index[columns]
+
+    other = 10
+
+    bf_result = dataframe_bf.where(lambda x: x > 0, other).to_pandas()
+    pd_result = dataframe_pd.where(lambda x: x > 0, other)
+    pandas.testing.assert_frame_equal(bf_result, pd_result)
+
+
+def test_where_dataframe_cond_callable_other(scalars_df_index, scalars_pandas_df_index):
+    # Condition is a dataframe, other is callable.
+    columns = ["int64_col", "float64_col"]
+    dataframe_bf = scalars_df_index[columns]
+    dataframe_pd = scalars_pandas_df_index[columns]
+
+    cond_bf = dataframe_bf > 0
+    cond_pd = dataframe_pd > 0
+
+    def func(x):
+        return x * 2
+
+    bf_result = dataframe_bf.where(cond_bf, func).to_pandas()
+    pd_result = dataframe_pd.where(cond_pd, func)
+    pandas.testing.assert_frame_equal(bf_result, pd_result)
+
+
+def test_where_callable_cond_callable_other(scalars_df_index, scalars_pandas_df_index):
+    # Condition is callable, other is callable too.
+    columns = ["int64_col", "float64_col"]
+    dataframe_bf = scalars_df_index[columns]
+    dataframe_pd = scalars_pandas_df_index[columns]
+
+    def func(x):
+        return x["int64_col"] > 0
+
+    bf_result = dataframe_bf.where(func, lambda x: x * 2).to_pandas()
+    pd_result = dataframe_pd.where(func, lambda x: x * 2)
+    pandas.testing.assert_frame_equal(bf_result, pd_result)
+
+
 def test_drop_column(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
     col_name = "int64_col"