@@ -197,30 +197,35 @@ def _explode_array_columns(
197197 if not array_columns :
198198 return dataframe , {}
199199
200- non_array_columns = dataframe .columns .drop (array_columns ).tolist ()
200+ original_cols = dataframe .columns .tolist ()
201+ work_df = dataframe
202+
203+ non_array_columns = work_df .columns .drop (array_columns ).tolist ()
201204 if not non_array_columns :
205+ work_df = work_df .copy () # Avoid modifying input
202206 # Add a temporary column to allow grouping if all columns are arrays
203207 non_array_columns = ["_temp_grouping_col" ]
204- dataframe ["_temp_grouping_col" ] = range (len (dataframe ))
208+ work_df ["_temp_grouping_col" ] = range (len (work_df ))
205209
206210 # Preserve original index
207- if dataframe .index .name :
208- original_index_name = dataframe .index .name
209- dataframe = dataframe .reset_index ()
211+ if work_df .index .name :
212+ original_index_name = work_df .index .name
213+ work_df = work_df .reset_index ()
210214 non_array_columns .append (original_index_name )
211215 else :
212216 original_index_name = None
213- dataframe = dataframe .reset_index (names = ["_original_index" ])
217+ work_df = work_df .reset_index (names = ["_original_index" ])
214218 non_array_columns .append ("_original_index" )
215219
216220 exploded_dfs = []
217221 for col in array_columns :
218222 # Explode each array column individually
219- exploded = dataframe [non_array_columns + [col ]].explode (col )
223+ exploded = work_df [non_array_columns + [col ]].explode (col )
220224 exploded ["_row_num" ] = exploded .groupby (non_array_columns ).cumcount ()
221225 exploded_dfs .append (exploded )
222226
223227 if not exploded_dfs :
228+ # This should not be reached if array_columns is not empty
224229 return dataframe , {}
225230
226231 # Merge the exploded columns
@@ -234,39 +239,26 @@ def _explode_array_columns(
234239 )
235240
236241 # Restore original column order and sort
237- final_cols = dataframe .columns .tolist () + ["_row_num" ]
238242 merged_df = merged_df .sort_values (non_array_columns + ["_row_num" ]).reset_index (
239243 drop = True
240244 )
241245
242246 # Create row groups
243247 array_row_groups = {}
244- if "_original_index" in merged_df .columns :
245- grouping_col = "_original_index"
246- elif original_index_name :
247- grouping_col = original_index_name
248- else :
249- # Fallback if no clear grouping column is identified
250- grouping_col = non_array_columns [0 ]
251-
252- for orig_idx , group in merged_df .groupby (grouping_col ):
253- array_row_groups [str (orig_idx )] = group .index .tolist ()
254-
255- # Clean up temporary columns
256- if "_temp_grouping_col" in merged_df .columns :
257- merged_df = merged_df .drop (columns = ["_temp_grouping_col" ])
258- final_cols .remove ("_temp_grouping_col" )
259- if "_original_index" in merged_df .columns :
260- merged_df = merged_df .drop (columns = ["_original_index" ])
261- final_cols .remove ("_original_index" )
262- if original_index_name :
263- merged_df = merged_df .set_index (original_index_name )
264- final_cols .remove (original_index_name )
248+ grouping_col_name = (
249+ "_original_index" if original_index_name is None else original_index_name
250+ )
251+ if grouping_col_name in merged_df .columns :
252+ for orig_idx , group in merged_df .groupby (grouping_col_name ):
253+ array_row_groups [str (orig_idx )] = group .index .tolist ()
254+
255+ # Restore original columns
256+ result_df = merged_df [original_cols ]
265257
266- final_cols . remove ( "_row_num" )
267- merged_df = merged_df [ final_cols ]
258+ if original_index_name :
259+ result_df = result_df . set_index ( original_index_name )
268260
269- return merged_df , array_row_groups
261+ return result_df , array_row_groups
270262
271263
272264def _flatten_struct_columns (
0 commit comments