Skip to content

Commit e41e83d

Browse files
committed
optimize Series.dot for dataframe with single level columns
1 parent f794ff6 commit e41e83d

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

bigframes/series.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,25 @@ def rdivmod(self, other) -> Tuple[Series, Series]: # type: ignore
680680
return (self.rfloordiv(other), self.rmod(other))
681681

682682
def __matmul__(self, other):
683-
if isinstance(other, bigframes.dataframe.DataFrame):
684-
return Series(
683+
if isinstance(other, Series):
684+
return (self * other).sum()
685+
686+
# At this point other must be a DataFrame
687+
if len(other.columns.names) == 1:
688+
# Single level columns in other
689+
na_df = other.isna().any()
690+
mul_df = Series(
691+
[(self * other[col]).sum() for col in other.columns],
692+
index=other.columns,
693+
name=self.name,
694+
)
695+
result = mul_df.mask(na_df)
696+
else:
697+
# Multi level columns in other
698+
# TODO(b/313747368): Remove this once DataFrame.any() honors
699+
# multi-level index, as the logic in the if clause should generalize
700+
# for multi-level columns in other
701+
result = Series(
685702
[
686703
pandas.NA if other[col].isna().any() else (self * other[col]).sum()
687704
for col in other.columns
@@ -690,8 +707,7 @@ def __matmul__(self, other):
690707
name=self.name,
691708
)
692709

693-
# At this point other must be a Series
694-
return (self * other).sum()
710+
return result
695711

696712
dot = __matmul__
697713

tests/system/small/test_multiindex.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,10 +1055,15 @@ def test_series_dot_df_column_multi_index():
10551055
[["col0", "col0", "col1"], ["col00", "col01", "col11"]]
10561056
)
10571057

1058-
bf_result = bpd.Series(left) @ bpd.DataFrame(right, columns=multi_level_columns)
1059-
pd_result = pandas.Series(left) @ pandas.DataFrame(
1060-
right, columns=multi_level_columns
1061-
)
1058+
bf_left_s = bpd.Series(left)
1059+
bf_right_df = bpd.DataFrame(right)
1060+
bf_right_df.columns = multi_level_columns
1061+
bf_result = bf_left_s @ bf_right_df
1062+
1063+
pd_left_s = pandas.Series(left)
1064+
pd_right_df = pandas.DataFrame(right)
1065+
pd_right_df.columns = multi_level_columns
1066+
pd_result = pd_left_s @ pd_right_df
10621067

10631068
pandas.testing.assert_series_equal(
10641069
bf_result.to_pandas(), pd_result, check_index_type=False, check_dtype=False

0 commit comments

Comments
 (0)