diff --git a/bigframes/core/reshape/merge.py b/bigframes/core/reshape/merge.py index e1750d5c7a..5c6cba4915 100644 --- a/bigframes/core/reshape/merge.py +++ b/bigframes/core/reshape/merge.py @@ -18,20 +18,17 @@ from __future__ import annotations -import typing -from typing import Literal, Optional +from typing import Literal, Sequence import bigframes_vendored.pandas.core.reshape.merge as vendored_pandas_merge -# Avoid cirular imports. -if typing.TYPE_CHECKING: - import bigframes.dataframe - import bigframes.series +from bigframes import dataframe, series +from bigframes.core import blocks, utils def merge( - left: bigframes.dataframe.DataFrame, - right: bigframes.dataframe.DataFrame, + left: dataframe.DataFrame, + right: dataframe.DataFrame, how: Literal[ "inner", "left", @@ -39,33 +36,75 @@ def merge( "right", "cross", ] = "inner", - on: Optional[str] = None, + on: blocks.Label | Sequence[blocks.Label] | None = None, *, - left_on: Optional[str] = None, - right_on: Optional[str] = None, + left_on: blocks.Label | Sequence[blocks.Label] | None = None, + right_on: blocks.Label | Sequence[blocks.Label] | None = None, sort: bool = False, suffixes: tuple[str, str] = ("_x", "_y"), -) -> bigframes.dataframe.DataFrame: +) -> dataframe.DataFrame: left = _validate_operand(left) right = _validate_operand(right) - return left.merge( - right, - how=how, - on=on, - left_on=left_on, - right_on=right_on, + if how == "cross": + if on is not None: + raise ValueError("'on' is not supported for cross join.") + result_block = left._block.merge( + right._block, + left_join_ids=[], + right_join_ids=[], + suffixes=suffixes, + how=how, + sort=True, + ) + return dataframe.DataFrame(result_block) + + left_on, right_on = _validate_left_right_on( + left, right, on, left_on=left_on, right_on=right_on + ) + + if utils.is_list_like(left_on): + left_on = list(left_on) # type: ignore + else: + left_on = [left_on] + + if utils.is_list_like(right_on): + right_on = list(right_on) # type: ignore + else: + right_on = [right_on] + + left_join_ids = [] + for label in left_on: # type: ignore + left_col_id = left._resolve_label_exact(label) + # 0 elements already throws an exception + if not left_col_id: + raise ValueError(f"No column {label} found in self.") + left_join_ids.append(left_col_id) + + right_join_ids = [] + for label in right_on: # type: ignore + right_col_id = right._resolve_label_exact(label) + if not right_col_id: + raise ValueError(f"No column {label} found in other.") + right_join_ids.append(right_col_id) + + block = left._block.merge( + right._block, + how, + left_join_ids, + right_join_ids, sort=sort, suffixes=suffixes, ) + return dataframe.DataFrame(block) merge.__doc__ = vendored_pandas_merge.merge.__doc__ def _validate_operand( - obj: bigframes.dataframe.DataFrame | bigframes.series.Series, -) -> bigframes.dataframe.DataFrame: + obj: dataframe.DataFrame | series.Series, +) -> dataframe.DataFrame: import bigframes.dataframe import bigframes.series @@ -79,3 +118,39 @@ def _validate_operand( raise TypeError( f"Can only merge bigframes.series.Series or bigframes.dataframe.DataFrame objects, a {type(obj)} was passed" ) + + +def _validate_left_right_on( + left: dataframe.DataFrame, + right: dataframe.DataFrame, + on: blocks.Label | Sequence[blocks.Label] | None = None, + *, + left_on: blocks.Label | Sequence[blocks.Label] | None = None, + right_on: blocks.Label | Sequence[blocks.Label] | None = None, +): + if on is not None: + if left_on is not None or right_on is not None: + raise ValueError( + "Can not pass both `on` and `left_on` + `right_on` params." + ) + return on, on + + if left_on is not None and right_on is not None: + return left_on, right_on + + left_cols = left.columns + right_cols = right.columns + common_cols = left_cols.intersection(right_cols) + if len(common_cols) == 0: + raise ValueError( + "No common columns to perform merge on." + f"Merge options: left_on={left_on}, " + f"right_on={right_on}, " + ) + if ( + not left_cols.join(common_cols, how="inner").is_unique + or not right_cols.join(common_cols, how="inner").is_unique + ): + raise ValueError(f"Data columns not unique: {repr(common_cols)}") + + return common_cols, common_cols diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index f016fddd83..df8c87416f 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3653,92 +3653,18 @@ def merge( sort: bool = False, suffixes: tuple[str, str] = ("_x", "_y"), ) -> DataFrame: - if how == "cross": - if on is not None: - raise ValueError("'on' is not supported for cross join.") - result_block = self._block.merge( - right._block, - left_join_ids=[], - right_join_ids=[], - suffixes=suffixes, - how=how, - sort=True, - ) - return DataFrame(result_block) - - left_on, right_on = self._validate_left_right_on( - right, on, left_on=left_on, right_on=right_on - ) - - if utils.is_list_like(left_on): - left_on = list(left_on) # type: ignore - else: - left_on = [left_on] + from bigframes.core.reshape import merge - if utils.is_list_like(right_on): - right_on = list(right_on) # type: ignore - else: - right_on = [right_on] - - left_join_ids = [] - for label in left_on: # type: ignore - left_col_id = self._resolve_label_exact(label) - # 0 elements already throws an exception - if not left_col_id: - raise ValueError(f"No column {label} found in self.") - left_join_ids.append(left_col_id) - - right_join_ids = [] - for label in right_on: # type: ignore - right_col_id = right._resolve_label_exact(label) - if not right_col_id: - raise ValueError(f"No column {label} found in other.") - right_join_ids.append(right_col_id) - - block = self._block.merge( - right._block, + return merge.merge( + self, + right, how, - left_join_ids, - right_join_ids, + on, + left_on=left_on, + right_on=right_on, sort=sort, suffixes=suffixes, ) - return DataFrame(block) - - def _validate_left_right_on( - self, - right: DataFrame, - on: Union[blocks.Label, Sequence[blocks.Label], None] = None, - *, - left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None, - right_on: Union[blocks.Label, Sequence[blocks.Label], None] = None, - ): - if on is not None: - if left_on is not None or right_on is not None: - raise ValueError( - "Can not pass both `on` and `left_on` + `right_on` params." - ) - return on, on - - if left_on is not None and right_on is not None: - return left_on, right_on - - left_cols = self.columns - right_cols = right.columns - common_cols = left_cols.intersection(right_cols) - if len(common_cols) == 0: - raise ValueError( - "No common columns to perform merge on." - f"Merge options: left_on={left_on}, " - f"right_on={right_on}, " - ) - if ( - not left_cols.join(common_cols, how="inner").is_unique - or not right_cols.join(common_cols, how="inner").is_unique - ): - raise ValueError(f"Data columns not unique: {repr(common_cols)}") - - return common_cols, common_cols def join( self,