Skip to content

Commit 3d262c3

Browse files
authored
Merge branch 'main' into release-please--branches--main
2 parents 8f768ee + dba2a6e commit 3d262c3

File tree

5 files changed

+170
-19
lines changed

5 files changed

+170
-19
lines changed

bigframes/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2741,9 +2741,9 @@ def where(self, cond, other=None):
27412741
if isinstance(other, bigframes.series.Series):
27422742
raise ValueError("Seires is not a supported replacement type!")
27432743

2744-
if self.columns.nlevels > 1 or self.index.nlevels > 1:
2744+
if self.columns.nlevels > 1:
27452745
raise NotImplementedError(
2746-
"The dataframe.where() method does not support multi-index and/or multi-column."
2746+
"The dataframe.where() method does not support multi-column."
27472747
)
27482748

27492749
aligned_block, (_, _) = self._block.join(cond._block, how="left")

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,6 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
375375
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False)
376376

377377

378-
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
379-
# Condition is dataframe, other is None (as default).
380-
cond_bf = scalars_df_index["int64_col"] > 0
381-
cond_pd = scalars_pandas_df_index["int64_col"] > 0
382-
bf_result = scalars_df_index.where(cond_bf).to_pandas()
383-
pd_result = scalars_pandas_df_index.where(cond_pd)
384-
pandas.testing.assert_frame_equal(bf_result, pd_result)
385-
386-
387378
def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
388379
cond_bf = scalars_df_index["int64_col"] > 0
389380
cond_pd = scalars_pandas_df_index["int64_col"] > 0
@@ -395,8 +386,8 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
395386
pandas.testing.assert_frame_equal(bf_result, pd_result)
396387

397388

398-
def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
399-
# Test when a dataframe has multi-index or multi-columns.
389+
def test_where_multi_column(scalars_df_index, scalars_pandas_df_index):
390+
# Test when a dataframe has multi-columns.
400391
columns = ["int64_col", "float64_col"]
401392
dataframe_bf = scalars_df_index[columns]
402393

@@ -409,10 +400,19 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
409400
dataframe_bf.where(cond_bf).to_pandas()
410401
assert (
411402
str(context.value)
412-
== "The dataframe.where() method does not support multi-index and/or multi-column."
403+
== "The dataframe.where() method does not support multi-column."
413404
)
414405

415406

407+
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
408+
# Condition is a series, other is None (as default).
409+
cond_bf = scalars_df_index["int64_col"] > 0
410+
cond_pd = scalars_pandas_df_index["int64_col"] > 0
411+
bf_result = scalars_df_index.where(cond_bf).to_pandas()
412+
pd_result = scalars_pandas_df_index.where(cond_pd)
413+
pandas.testing.assert_frame_equal(bf_result, pd_result)
414+
415+
416416
def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index):
417417
# Condition is a series, other is a constant.
418418
columns = ["int64_col", "float64_col"]

tests/system/small/test_multiindex.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@
1919
import bigframes.pandas as bpd
2020
from bigframes.testing.utils import assert_pandas_df_equal
2121

22+
# Sample MultiIndex for testing the DataFrame.where() method.
23+
_MULTI_INDEX = pandas.MultiIndex.from_tuples(
24+
[
25+
(0, "a"),
26+
(1, "b"),
27+
(2, "c"),
28+
(0, "d"),
29+
(1, "e"),
30+
(2, "f"),
31+
(0, "g"),
32+
(1, "h"),
33+
(2, "i"),
34+
],
35+
names=["A", "B"],
36+
)
37+
2238

2339
def test_multi_index_from_arrays():
2440
bf_idx = bpd.MultiIndex.from_arrays(
@@ -541,6 +557,140 @@ def test_multi_index_dataframe_join_on(scalars_dfs, how):
541557
assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
542558

543559

560+
def test_multi_index_dataframe_where_series_cond_none_other(
561+
scalars_df_index, scalars_pandas_df_index
562+
):
563+
columns = ["int64_col", "float64_col"]
564+
565+
# Create multi-index dataframe.
566+
dataframe_bf = bpd.DataFrame(
567+
scalars_df_index[columns].values,
568+
index=_MULTI_INDEX,
569+
columns=scalars_df_index[columns].columns,
570+
)
571+
dataframe_pd = pandas.DataFrame(
572+
scalars_pandas_df_index[columns].values,
573+
index=_MULTI_INDEX,
574+
columns=scalars_pandas_df_index[columns].columns,
575+
)
576+
dataframe_bf.columns.name = "test_name"
577+
dataframe_pd.columns.name = "test_name"
578+
579+
# When condition is series and other is None.
580+
series_cond_bf = dataframe_bf["int64_col"] > 0
581+
series_cond_pd = dataframe_pd["int64_col"] > 0
582+
583+
bf_result = dataframe_bf.where(series_cond_bf).to_pandas()
584+
pd_result = dataframe_pd.where(series_cond_pd)
585+
pandas.testing.assert_frame_equal(
586+
bf_result,
587+
pd_result,
588+
check_index_type=False,
589+
check_dtype=False,
590+
)
591+
# Assert the index is still MultiIndex after the operation.
592+
assert isinstance(bf_result.index, pandas.MultiIndex), "Expected a MultiIndex"
593+
assert isinstance(pd_result.index, pandas.MultiIndex), "Expected a MultiIndex"
594+
595+
596+
def test_multi_index_dataframe_where_series_cond_dataframe_other(
597+
scalars_df_index, scalars_pandas_df_index
598+
):
599+
columns = ["int64_col", "int64_too"]
600+
601+
# Create multi-index dataframe.
602+
dataframe_bf = bpd.DataFrame(
603+
scalars_df_index[columns].values,
604+
index=_MULTI_INDEX,
605+
columns=scalars_df_index[columns].columns,
606+
)
607+
dataframe_pd = pandas.DataFrame(
608+
scalars_pandas_df_index[columns].values,
609+
index=_MULTI_INDEX,
610+
columns=scalars_pandas_df_index[columns].columns,
611+
)
612+
613+
# When condition is series and other is dataframe.
614+
series_cond_bf = dataframe_bf["int64_col"] > 1000.0
615+
series_cond_pd = dataframe_pd["int64_col"] > 1000.0
616+
dataframe_other_bf = dataframe_bf * 100.0
617+
dataframe_other_pd = dataframe_pd * 100.0
618+
619+
bf_result = dataframe_bf.where(series_cond_bf, dataframe_other_bf).to_pandas()
620+
pd_result = dataframe_pd.where(series_cond_pd, dataframe_other_pd)
621+
pandas.testing.assert_frame_equal(
622+
bf_result,
623+
pd_result,
624+
check_index_type=False,
625+
check_dtype=False,
626+
)
627+
628+
629+
def test_multi_index_dataframe_where_dataframe_cond_constant_other(
630+
scalars_df_index, scalars_pandas_df_index
631+
):
632+
columns = ["int64_col", "float64_col"]
633+
634+
# Create multi-index dataframe.
635+
dataframe_bf = bpd.DataFrame(
636+
scalars_df_index[columns].values,
637+
index=_MULTI_INDEX,
638+
columns=scalars_df_index[columns].columns,
639+
)
640+
dataframe_pd = pandas.DataFrame(
641+
scalars_pandas_df_index[columns].values,
642+
index=_MULTI_INDEX,
643+
columns=scalars_pandas_df_index[columns].columns,
644+
)
645+
646+
# When condition is dataframe and other is a constant.
647+
dataframe_cond_bf = dataframe_bf > 0
648+
dataframe_cond_pd = dataframe_pd > 0
649+
other = 0
650+
651+
bf_result = dataframe_bf.where(dataframe_cond_bf, other).to_pandas()
652+
pd_result = dataframe_pd.where(dataframe_cond_pd, other)
653+
pandas.testing.assert_frame_equal(
654+
bf_result,
655+
pd_result,
656+
check_index_type=False,
657+
check_dtype=False,
658+
)
659+
660+
661+
def test_multi_index_dataframe_where_dataframe_cond_dataframe_other(
662+
scalars_df_index, scalars_pandas_df_index
663+
):
664+
columns = ["int64_col", "int64_too", "float64_col"]
665+
666+
# Create multi-index dataframe.
667+
dataframe_bf = bpd.DataFrame(
668+
scalars_df_index[columns].values,
669+
index=_MULTI_INDEX,
670+
columns=scalars_df_index[columns].columns,
671+
)
672+
dataframe_pd = pandas.DataFrame(
673+
scalars_pandas_df_index[columns].values,
674+
index=_MULTI_INDEX,
675+
columns=scalars_pandas_df_index[columns].columns,
676+
)
677+
678+
# When condition is dataframe and other is dataframe.
679+
dataframe_cond_bf = dataframe_bf < 1000.0
680+
dataframe_cond_pd = dataframe_pd < 1000.0
681+
dataframe_other_bf = dataframe_bf * -1.0
682+
dataframe_other_pd = dataframe_pd * -1.0
683+
684+
bf_result = dataframe_bf.where(dataframe_cond_bf, dataframe_other_bf).to_pandas()
685+
pd_result = dataframe_pd.where(dataframe_cond_pd, dataframe_other_pd)
686+
pandas.testing.assert_frame_equal(
687+
bf_result,
688+
pd_result,
689+
check_index_type=False,
690+
check_dtype=False,
691+
)
692+
693+
544694
@pytest.mark.parametrize(
545695
("level",),
546696
[

tests/unit/test_dataframe_polars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
364364
dataframe_bf.where(cond_bf).to_pandas()
365365
assert (
366366
str(context.value)
367-
== "The dataframe.where() method does not support multi-index and/or multi-column."
367+
== "The dataframe.where() method does not support multi-column."
368368
)
369369

370370

third_party/bigframes_vendored/tpch/queries/q15.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def q(project_id: str, dataset_id: str, session: bigframes.Session):
3131
.agg(TOTAL_REVENUE=bpd.NamedAgg(column="REVENUE", aggfunc="sum"))
3232
.rename(columns={"L_SUPPKEY": "SUPPLIER_NO"})
3333
)
34+
# Round earlier to prevent non-determinism in the later join due to
35+
# differences in distributed floating point operation sort order.
36+
grouped_revenue = grouped_revenue.assign(
37+
TOTAL_REVENUE=grouped_revenue["TOTAL_REVENUE"].round(2)
38+
)
3439

3540
joined_data = bpd.merge(
3641
supplier, grouped_revenue, left_on="S_SUPPKEY", right_on="SUPPLIER_NO"
@@ -43,10 +48,6 @@ def q(project_id: str, dataset_id: str, session: bigframes.Session):
4348
max_revenue_suppliers = joined_data[
4449
joined_data["TOTAL_REVENUE"] == joined_data["MAX_REVENUE"]
4550
]
46-
47-
max_revenue_suppliers["TOTAL_REVENUE"] = max_revenue_suppliers[
48-
"TOTAL_REVENUE"
49-
].round(2)
5051
q_final = max_revenue_suppliers[
5152
["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_PHONE", "TOTAL_REVENUE"]
5253
].sort_values("S_SUPPKEY")

0 commit comments

Comments
 (0)