@@ -40,6 +40,8 @@ def merge(
4040 * ,
4141 left_on : blocks .Label | Sequence [blocks .Label ] | None = None ,
4242 right_on : blocks .Label | Sequence [blocks .Label ] | None = None ,
43+ left_index : bool = False ,
44+ right_index : bool = False ,
4345 sort : bool = False ,
4446 suffixes : tuple [str , str ] = ("_x" , "_y" ),
4547) -> dataframe .DataFrame :
@@ -59,42 +61,25 @@ def merge(
5961 )
6062 return dataframe .DataFrame (result_block )
6163
62- left_on , right_on = _validate_left_right_on (
63- left , right , on , left_on = left_on , right_on = right_on
64+ left_join_ids , right_join_ids = _validate_left_right_on (
65+ left ,
66+ right ,
67+ on ,
68+ left_on = left_on ,
69+ right_on = right_on ,
70+ left_index = left_index ,
71+ right_index = right_index ,
6472 )
6573
66- if utils .is_list_like (left_on ):
67- left_on = list (left_on ) # type: ignore
68- else :
69- left_on = [left_on ]
70-
71- if utils .is_list_like (right_on ):
72- right_on = list (right_on ) # type: ignore
73- else :
74- right_on = [right_on ]
75-
76- left_join_ids = []
77- for label in left_on : # type: ignore
78- left_col_id = left ._resolve_label_exact (label )
79- # 0 elements already throws an exception
80- if not left_col_id :
81- raise ValueError (f"No column { label } found in self." )
82- left_join_ids .append (left_col_id )
83-
84- right_join_ids = []
85- for label in right_on : # type: ignore
86- right_col_id = right ._resolve_label_exact (label )
87- if not right_col_id :
88- raise ValueError (f"No column { label } found in other." )
89- right_join_ids .append (right_col_id )
90-
9174 block = left ._block .merge (
9275 right ._block ,
9376 how ,
9477 left_join_ids ,
9578 right_join_ids ,
9679 sort = sort ,
9780 suffixes = suffixes ,
81+ left_index = left_index ,
82+ right_index = right_index
9883 )
9984 return dataframe .DataFrame (block )
10085
@@ -127,30 +112,97 @@ def _validate_left_right_on(
127112 * ,
128113 left_on : blocks .Label | Sequence [blocks .Label ] | None = None ,
129114 right_on : blocks .Label | Sequence [blocks .Label ] | None = None ,
130- ):
131- if on is not None :
115+ left_index : bool = False ,
116+ right_index : bool = False ,
117+ ) -> tuple [list [str ], list [str ]]:
118+ # Turn left_on and right_on to lists
119+ if left_on is not None and not isinstance (left_on , (tuple , list )):
120+ left_on = [left_on ]
121+ if right_on is not None and not isinstance (right_on , (tuple , list )):
122+ right_on = [right_on ]
123+
124+ # The following checks are copied from Pandas.
125+ if on is None and left_on is None and right_on is None :
126+ if left_index and right_index :
127+ return list (left ._block .index_columns ), (right ._block .index_columns )
128+ elif left_index :
129+ raise ValueError ("Must pass right_on or right_index=True" )
130+ elif right_index :
131+ raise ValueError ("Must pass left_on or left_index=True" )
132+ else :
133+ # use the common columns
134+ common_cols = left .columns .intersection (right .columns )
135+ if len (common_cols ) == 0 :
136+ raise ValueError (
137+ "No common columns to perform merge on. "
138+ f"Merge options: left_on={ left_on } , "
139+ f"right_on={ right_on } , "
140+ f"left_index={ left_index } , "
141+ f"right_index={ right_index } "
142+ )
143+ if (
144+ not left .columns .join (common_cols , how = "inner" ).is_unique
145+ or not right .columns .join (common_cols , how = "inner" ).is_unique
146+ ):
147+ raise ValueError (f"Data columns not unique: { repr (common_cols )} " )
148+ return _to_col_ids (left , common_cols ), _to_col_ids (right , common_cols )
149+
150+ elif on is not None :
132151 if left_on is not None or right_on is not None :
133152 raise ValueError (
134- "Can not pass both `on` and `left_on` + `right_on` params."
153+ 'Can only pass argument "on" OR "left_on" '
154+ 'and "right_on", not a combination of both.'
135155 )
136- return on , on
137-
138- if left_on is not None and right_on is not None :
139- return left_on , right_on
140-
141- left_cols = left .columns
142- right_cols = right .columns
143- common_cols = left_cols .intersection (right_cols )
144- if len (common_cols ) == 0 :
145- raise ValueError (
146- "No common columns to perform merge on."
147- f"Merge options: left_on={ left_on } , "
148- f"right_on={ right_on } , "
149- )
150- if (
151- not left_cols .join (common_cols , how = "inner" ).is_unique
152- or not right_cols .join (common_cols , how = "inner" ).is_unique
153- ):
154- raise ValueError (f"Data columns not unique: { repr (common_cols )} " )
156+ if left_index or right_index :
157+ raise ValueError (
158+ 'Can only pass argument "on" OR "left_index" '
159+ 'and "right_index", not a combination of both.'
160+ )
161+ return _to_col_ids (left , on ), _to_col_ids (right , on )
155162
156- return common_cols , common_cols
163+ elif left_on is not None :
164+ if left_index :
165+ raise ValueError (
166+ 'Can only pass argument "left_on" OR "left_index" not both.'
167+ )
168+ if not right_index and right_on is None :
169+ raise ValueError ('Must pass "right_on" OR "right_index".' )
170+ n = len (left_on )
171+ if right_index :
172+ if len (left_on ) != right .index .nlevels :
173+ raise ValueError (
174+ "len(left_on) must equal the number "
175+ 'of levels in the index of "right"'
176+ )
177+ return _to_col_ids (left , left_on ), list (right ._block .index_columns )
178+
179+ elif right_on is not None :
180+ if right_index :
181+ raise ValueError (
182+ 'Can only pass argument "right_on" OR "right_index" not both.'
183+ )
184+ if not left_index and left_on is None :
185+ raise ValueError ('Must pass "left_on" OR "left_index".' )
186+ n = len (right_on )
187+ if left_index :
188+ if len (right_on ) != left .index .nlevels :
189+ raise ValueError (
190+ "len(right_on) must equal the number "
191+ 'of levels in the index of "left"'
192+ )
193+ return list (left ._block .index_columns ), _to_col_ids (right , right_on )
194+
195+ # The user correctly specified left_on and right_on
196+ if len (right_on ) != len (left_on ):
197+ raise ValueError ("len(right_on) must equal len(left_on)" )
198+
199+ return _to_col_ids (left , left_on ), _to_col_ids (right , right_on )
200+
201+
202+ def _to_col_ids (
203+ df : dataframe .DataFrame , join_cols : blocks .Label | Sequence [blocks .Label ]
204+ ) -> list [str ]:
205+ if utils .is_list_like (join_cols ):
206+ return [df ._block .resolve_label_exact_or_error (col ) for col in join_cols ]
207+
208+ return [df ._block .resolve_label_exact_or_error (join_cols )]
0 commit comments