Skip to content

Commit 47bfeb9

Browse files
committed
feat: support Series.dot on a DataFrame input
1 parent 71844b0 commit 47bfeb9

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

bigframes/series.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,17 @@ def rdivmod(self, other) -> Tuple[Series, Series]: # type: ignore
680680
return (self.rfloordiv(other), self.rmod(other))
681681

682682
def __matmul__(self, other):
683+
if isinstance(other, bigframes.dataframe.DataFrame):
684+
return Series(
685+
[
686+
pandas.NA if other[col].isna().any() else (self * other[col]).sum()
687+
for col in other.columns
688+
],
689+
index=other.columns,
690+
name=self.name,
691+
)
692+
693+
# At this point other must be a Series
683694
return (self * other).sum()
684695

685696
dot = __matmul__

tests/system/small/test_series.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pyarrow as pa # type: ignore
2323
import pytest
2424

25+
import bigframes.dataframe as dataframe
2526
import bigframes.pandas
2627
import bigframes.series as series
2728
from tests.system.utils import assert_pandas_df_equal, assert_series_equal
@@ -2266,6 +2267,30 @@ def test_dot(scalars_dfs):
22662267
assert bf_result == pd_result
22672268

22682269

2270+
def test_dot_df(scalars_dfs):
2271+
scalars_df, scalars_pandas_df = scalars_dfs
2272+
bf_result = scalars_df["int64_too"] @ scalars_df[["int64_col", "int64_too"]]
2273+
pd_result = (
2274+
scalars_pandas_df["int64_too"] @ scalars_pandas_df[["int64_col", "int64_too"]]
2275+
)
2276+
2277+
pd.testing.assert_series_equal(
2278+
bf_result.to_pandas(), pd_result, check_index_type=False, check_dtype=False
2279+
)
2280+
2281+
2282+
def test_dot_df_inline(scalars_dfs):
2283+
left = [10, 11, 12, 13] # series data
2284+
right = [[0, 1], [-2, 3], [4, -5], [6, 7]] # dataframe data
2285+
2286+
bf_result = series.Series(left) @ dataframe.DataFrame(right)
2287+
pd_result = pd.Series(left) @ pd.DataFrame(right)
2288+
2289+
pd.testing.assert_series_equal(
2290+
bf_result.to_pandas(), pd_result, check_index_type=False, check_dtype=False
2291+
)
2292+
2293+
22692294
@pytest.mark.parametrize(
22702295
("left", "right", "inclusive"),
22712296
[

0 commit comments

Comments
 (0)