Skip to content

Commit 94cace3

Browse files
committed
feat: support left_index and right_index for merge
1 parent 5e006e4 commit 94cace3

File tree

4 files changed

+212
-57
lines changed

4 files changed

+212
-57
lines changed

bigframes/core/blocks.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,6 +2303,8 @@ def merge(
23032303
right_join_ids: typing.Sequence[str],
23042304
sort: bool,
23052305
suffixes: tuple[str, str] = ("_x", "_y"),
2306+
left_index: bool = False,
2307+
right_index: bool = False,
23062308
) -> Block:
23072309
conditions = tuple(
23082310
(lid, rid) for lid, rid in zip(left_join_ids, right_join_ids)
@@ -2324,9 +2326,8 @@ def merge(
23242326
if col_id in left_join_ids:
23252327
key_part = left_join_ids.index(col_id)
23262328
matching_right_id = right_join_ids[key_part]
2327-
if (
2328-
self.col_id_to_label[col_id]
2329-
== other.col_id_to_label[matching_right_id]
2329+
if self.col_id_to_label[col_id] == other.col_id_to_label.get(
2330+
matching_right_id, None
23302331
):
23312332
matching_join_labels.append(self.col_id_to_label[col_id])
23322333
result_columns.append(coalesced_ids[key_part])
@@ -2371,13 +2372,15 @@ def merge(
23712372
or other.index.is_null
23722373
or self.session._default_index_type == bigframes.enums.DefaultIndexKind.NULL
23732374
):
2374-
expr = joined_expr
2375-
index_columns = []
2375+
return Block(joined_expr, index_columns=[], column_labels=labels)
2376+
elif left_index:
2377+
return Block(joined_expr, index_columns=[left_post_join_ids], column_labels=labels)
2378+
elif right_index:
2379+
return Block(joined_expr, index_columns=[right_post_join_ids], column_labels=labels)
23762380
else:
23772381
expr, offset_index_id = joined_expr.promote_offsets()
23782382
index_columns = [offset_index_id]
2379-
2380-
return Block(expr, index_columns=index_columns, column_labels=labels)
2383+
return Block(expr, index_columns=index_columns, column_labels=labels)
23812384

23822385
def _align_both_axes(
23832386
self, other: Block, how: str

bigframes/core/reshape/merge.py

Lines changed: 102 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def merge(
4040
*,
4141
left_on: blocks.Label | Sequence[blocks.Label] | None = None,
4242
right_on: blocks.Label | Sequence[blocks.Label] | None = None,
43+
left_index: bool = False,
44+
right_index: bool = False,
4345
sort: bool = False,
4446
suffixes: tuple[str, str] = ("_x", "_y"),
4547
) -> dataframe.DataFrame:
@@ -59,42 +61,25 @@ def merge(
5961
)
6062
return dataframe.DataFrame(result_block)
6163

62-
left_on, right_on = _validate_left_right_on(
63-
left, right, on, left_on=left_on, right_on=right_on
64+
left_join_ids, right_join_ids = _validate_left_right_on(
65+
left,
66+
right,
67+
on,
68+
left_on=left_on,
69+
right_on=right_on,
70+
left_index=left_index,
71+
right_index=right_index,
6472
)
6573

66-
if utils.is_list_like(left_on):
67-
left_on = list(left_on) # type: ignore
68-
else:
69-
left_on = [left_on]
70-
71-
if utils.is_list_like(right_on):
72-
right_on = list(right_on) # type: ignore
73-
else:
74-
right_on = [right_on]
75-
76-
left_join_ids = []
77-
for label in left_on: # type: ignore
78-
left_col_id = left._resolve_label_exact(label)
79-
# 0 elements already throws an exception
80-
if not left_col_id:
81-
raise ValueError(f"No column {label} found in self.")
82-
left_join_ids.append(left_col_id)
83-
84-
right_join_ids = []
85-
for label in right_on: # type: ignore
86-
right_col_id = right._resolve_label_exact(label)
87-
if not right_col_id:
88-
raise ValueError(f"No column {label} found in other.")
89-
right_join_ids.append(right_col_id)
90-
9174
block = left._block.merge(
9275
right._block,
9376
how,
9477
left_join_ids,
9578
right_join_ids,
9679
sort=sort,
9780
suffixes=suffixes,
81+
left_index=left_index,
82+
right_index=right_index
9883
)
9984
return dataframe.DataFrame(block)
10085

@@ -127,30 +112,97 @@ def _validate_left_right_on(
127112
*,
128113
left_on: blocks.Label | Sequence[blocks.Label] | None = None,
129114
right_on: blocks.Label | Sequence[blocks.Label] | None = None,
130-
):
131-
if on is not None:
115+
left_index: bool = False,
116+
right_index: bool = False,
117+
) -> tuple[list[str], list[str]]:
118+
# Wrap a scalar left_on / right_on in a list so both can be handled as sequences
119+
if left_on is not None and not isinstance(left_on, (tuple, list)):
120+
left_on = [left_on]
121+
if right_on is not None and not isinstance(right_on, (tuple, list)):
122+
right_on = [right_on]
123+
124+
# The following checks are copied from Pandas.
125+
if on is None and left_on is None and right_on is None:
126+
if left_index and right_index:
127+
return list(left._block.index_columns), (right._block.index_columns)
128+
elif left_index:
129+
raise ValueError("Must pass right_on or right_index=True")
130+
elif right_index:
131+
raise ValueError("Must pass left_on or left_index=True")
132+
else:
133+
# use the common columns
134+
common_cols = left.columns.intersection(right.columns)
135+
if len(common_cols) == 0:
136+
raise ValueError(
137+
"No common columns to perform merge on. "
138+
f"Merge options: left_on={left_on}, "
139+
f"right_on={right_on}, "
140+
f"left_index={left_index}, "
141+
f"right_index={right_index}"
142+
)
143+
if (
144+
not left.columns.join(common_cols, how="inner").is_unique
145+
or not right.columns.join(common_cols, how="inner").is_unique
146+
):
147+
raise ValueError(f"Data columns not unique: {repr(common_cols)}")
148+
return _to_col_ids(left, common_cols), _to_col_ids(right, common_cols)
149+
150+
elif on is not None:
132151
if left_on is not None or right_on is not None:
133152
raise ValueError(
134-
"Can not pass both `on` and `left_on` + `right_on` params."
153+
'Can only pass argument "on" OR "left_on" '
154+
'and "right_on", not a combination of both.'
135155
)
136-
return on, on
137-
138-
if left_on is not None and right_on is not None:
139-
return left_on, right_on
140-
141-
left_cols = left.columns
142-
right_cols = right.columns
143-
common_cols = left_cols.intersection(right_cols)
144-
if len(common_cols) == 0:
145-
raise ValueError(
146-
"No common columns to perform merge on."
147-
f"Merge options: left_on={left_on}, "
148-
f"right_on={right_on}, "
149-
)
150-
if (
151-
not left_cols.join(common_cols, how="inner").is_unique
152-
or not right_cols.join(common_cols, how="inner").is_unique
153-
):
154-
raise ValueError(f"Data columns not unique: {repr(common_cols)}")
156+
if left_index or right_index:
157+
raise ValueError(
158+
'Can only pass argument "on" OR "left_index" '
159+
'and "right_index", not a combination of both.'
160+
)
161+
return _to_col_ids(left, on), _to_col_ids(right, on)
155162

156-
return common_cols, common_cols
163+
elif left_on is not None:
164+
if left_index:
165+
raise ValueError(
166+
'Can only pass argument "left_on" OR "left_index" not both.'
167+
)
168+
if not right_index and right_on is None:
169+
raise ValueError('Must pass "right_on" OR "right_index".')
170+
n = len(left_on)
171+
if right_index:
172+
if len(left_on) != right.index.nlevels:
173+
raise ValueError(
174+
"len(left_on) must equal the number "
175+
'of levels in the index of "right"'
176+
)
177+
return _to_col_ids(left, left_on), list(right._block.index_columns)
178+
179+
elif right_on is not None:
180+
if right_index:
181+
raise ValueError(
182+
'Can only pass argument "right_on" OR "right_index" not both.'
183+
)
184+
if not left_index and left_on is None:
185+
raise ValueError('Must pass "left_on" OR "left_index".')
186+
n = len(right_on)
187+
if left_index:
188+
if len(right_on) != left.index.nlevels:
189+
raise ValueError(
190+
"len(right_on) must equal the number "
191+
'of levels in the index of "left"'
192+
)
193+
return list(left._block.index_columns), _to_col_ids(right, right_on)
194+
195+
# The user correctly specified left_on and right_on
196+
if len(right_on) != len(left_on):
197+
raise ValueError("len(right_on) must equal len(left_on)")
198+
199+
return _to_col_ids(left, left_on), _to_col_ids(right, right_on)
200+
201+
202+
def _to_col_ids(
    df: dataframe.DataFrame, join_cols: blocks.Label | Sequence[blocks.Label]
) -> list[str]:
    """Translate join label(s) into a list of column ids of ``df``.

    Args:
        df: DataFrame whose block is used to resolve the labels.
        join_cols: Either a single label or a list-like of labels, as
            decided by ``utils.is_list_like``.

    Returns:
        One resolved id per label, in the order the labels were given. A
        single label yields a one-element list.
    """
    # NOTE(review): presumably raises when a label has no exact match --
    # confirm against Block.resolve_label_exact_or_error.
    resolve = df._block.resolve_label_exact_or_error

    # Treat a lone label as a one-element sequence so both cases share a path.
    labels = join_cols if utils.is_list_like(join_cols) else (join_cols,)
    return [resolve(label) for label in labels]

bigframes/dataframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3650,6 +3650,8 @@ def merge(
36503650
*,
36513651
left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
36523652
right_on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
3653+
left_index: bool = False,
3654+
right_index: bool = False,
36533655
sort: bool = False,
36543656
suffixes: tuple[str, str] = ("_x", "_y"),
36553657
) -> DataFrame:
@@ -3662,6 +3664,8 @@ def merge(
36623664
on,
36633665
left_on=left_on,
36643666
right_on=right_on,
3667+
left_index=left_index,
3668+
right_index=right_index,
36653669
sort=sort,
36663670
suffixes=suffixes,
36673671
)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pandas.testing
17+
import pytest
18+
19+
from bigframes import session
20+
from bigframes.core.reshape import merge
21+
22+
23+
@pytest.mark.parametrize(
    ("left_on", "right_on", "left_index", "right_index"),
    [
        ("col_a", None, False, True),
        (None, "col_c", True, False),
        (None, None, True, True),
    ],
)
def test_join_with_index(
    session: session.Session, left_on, right_on, left_index, right_index
):
    """Check ``merge.merge`` against ``pd.merge`` for index-based join keys.

    Each parametrized case joins on the left index, the right index, or
    both; the two frames use the default index created by ``pd.DataFrame``.
    """
    # The same keyword arguments are handed to both implementations so the
    # comparison below differs only in the merge backend.
    join_kwargs = {
        "left_on": left_on,
        "right_on": right_on,
        "left_index": left_index,
        "right_index": right_index,
    }

    left_pd = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]})
    left_bf = session.read_pandas(left_pd)
    right_pd = pd.DataFrame({"col_c": [1, 2, 3], "col_d": [2, 3, 4]})
    right_bf = session.read_pandas(right_pd)

    actual = merge.merge(left_bf, right_bf, **join_kwargs).to_pandas()
    expected = pd.merge(left_pd, right_pd, **join_kwargs)

    # Dtypes and index types are deliberately not compared.
    pandas.testing.assert_frame_equal(
        actual, expected, check_dtype=False, check_index_type=False
    )
59+
60+
@pytest.mark.parametrize(
    ("left_on", "right_on", "left_index", "right_index"),
    [
        (["col_a", "col_b"], None, False, True),
        (None, ["col_c", "col_d"], True, False),
        (None, None, True, True),
    ],
)
def test_join_with_multiindex(
    session: session.Session, left_on, right_on, left_index, right_index
):
    """Check ``merge.merge`` against ``pd.merge`` when both frames share a
    two-level ``MultiIndex``.

    Each parametrized case joins on the left index, the right index, or
    both; column-side keys are two-element lists to match the index depth.
    """
    # The same keyword arguments are handed to both implementations so the
    # comparison below differs only in the merge backend.
    join_kwargs = {
        "left_on": left_on,
        "right_on": right_on,
        "left_index": left_index,
        "right_index": right_index,
    }

    shared_index = pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 4)])
    left_pd = pd.DataFrame(
        {"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=shared_index
    )
    left_bf = session.read_pandas(left_pd)
    right_pd = pd.DataFrame(
        {"col_c": [1, 2, 3], "col_d": [2, 3, 4]}, index=shared_index
    )
    right_bf = session.read_pandas(right_pd)

    actual = merge.merge(left_bf, right_bf, **join_kwargs).to_pandas()
    expected = pd.merge(left_pd, right_pd, **join_kwargs)

    # Dtypes and index types are deliberately not compared.
    pandas.testing.assert_frame_equal(
        actual, expected, check_dtype=False, check_index_type=False
    )

0 commit comments

Comments
 (0)