@@ -159,14 +159,14 @@ def _flatten_array_of_struct_columns(
         nested_originated_columns.add(new_col_name)
         new_array_col_names.append(new_col_name)

-        # Reconstruct ListArray for this field
-        # Use mask=arrow_array.is_null() to preserve nulls from the original list
+        # Reconstruct ListArray for this field. This transforms the
+        # array<struct<f1, f2>> into separate array<f1> and array<f2> columns.
         new_list_array = pa.ListArray.from_arrays(
             offsets, flattened_fields[field_idx], mask=arrow_array.is_null()
         )

         new_cols_to_add[new_col_name] = pd.Series(
-            new_list_array.to_pylist(),
+            new_list_array,
             dtype=pd.ArrowDtype(pa.list_(field.type)),
             index=result_df.index,
         )
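
Dropping `.to_pylist()` here matters because `pd.Series` can consume a pyarrow array directly when given a matching `pd.ArrowDtype`, keeping the column Arrow-backed instead of round-tripping through Python objects. A minimal sketch of the rebuild step, with hand-built toy data (the `offsets` values and the field name `x` are illustrative, not from the patch):

```python
import pandas as pd
import pyarrow as pa

# Toy array<struct<x: int64>> column with one null list at the end.
arr = pa.array(
    [[{"x": 1}, {"x": 2}], None],
    type=pa.list_(pa.struct([("x", pa.int64())])),
)

offsets = pa.array([0, 2, 2], type=pa.int32())  # hand-built for this toy input
x_values = arr.flatten().field("x")             # int64 values of every element

# Rebuild array<x> as its own list column; the mask restores the null row,
# mirroring mask=arrow_array.is_null() in the patch.
x_lists = pa.ListArray.from_arrays(offsets, x_values, mask=arr.is_null())

# Passing the Arrow array straight through keeps Arrow-backed storage;
# .to_pylist() would have materialized Python lists first.
s = pd.Series(x_lists, dtype=pd.ArrowDtype(pa.list_(pa.int64())))
print(s)  # row 0 -> [1, 2], row 1 -> <NA>
```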
@@ -194,78 +194,79 @@ def _explode_array_columns(
     dataframe: pd.DataFrame, array_columns: list[str]
 ) -> tuple[pd.DataFrame, dict[str, list[int]]]:
     """Explode array columns into new rows."""
-    exploded_rows = []
-    array_row_groups: dict[str, list[int]] = {}
+    if not array_columns:
+        return dataframe, {}
+
     non_array_columns = dataframe.columns.drop(array_columns).tolist()
-    non_array_df = dataframe[non_array_columns]
-
-    for orig_idx in dataframe.index:
-        non_array_data = non_array_df.loc[orig_idx].to_dict()
-        array_values = {}
-        max_len_in_row = 0
-        non_na_array_found = False
-
-        for col_name in array_columns:
-            val = dataframe.loc[orig_idx, col_name]
-            if val is not None and not (
-                isinstance(val, list) and len(val) == 1 and pd.isna(val[0])
-            ):
-                array_values[col_name] = list(val)
-                max_len_in_row = max(max_len_in_row, len(val))
-                non_na_array_found = True
-            else:
-                array_values[col_name] = []
-
-        if not non_na_array_found:
-            new_row = non_array_data.copy()
-            for col_name in array_columns:
-                new_row[f"{col_name}"] = pd.NA
-            exploded_rows.append(new_row)
-            orig_key = str(orig_idx)
-            if orig_key not in array_row_groups:
-                array_row_groups[orig_key] = []
-            array_row_groups[orig_key].append(len(exploded_rows) - 1)
-            continue
-
-        # Create one row per array element, up to max_len_in_row
-        for array_idx in range(max_len_in_row):
-            new_row = non_array_data.copy()
-
-            # Add the specific array element for this index
-            for col_name in array_columns:
-                if array_idx < len(array_values.get(col_name, [])):
-                    new_row[f"{col_name}"] = array_values[col_name][array_idx]
-                else:
-                    new_row[f"{col_name}"] = pd.NA
-
-            exploded_rows.append(new_row)
-
-            # Track which rows belong to which original row
-            orig_key = str(orig_idx)
-            if orig_key not in array_row_groups:
-                array_row_groups[orig_key] = []
-            array_row_groups[orig_key].append(len(exploded_rows) - 1)
-
-    if exploded_rows:
-        # Reconstruct the DataFrame to maintain original column order
-        exploded_df = pd.DataFrame(exploded_rows)[dataframe.columns]
-        for col in exploded_df.columns:
-            # After explosion, object columns that are all-numeric (except for NAs)
-            # should be converted to a numeric dtype for proper alignment.
-            if exploded_df[col].dtype == "object":
-                try:
-                    # Use nullable integer type to preserve integers
-                    exploded_df[col] = exploded_df[col].astype(pd.Int64Dtype())
-                except (ValueError, TypeError):
-                    # Fallback for non-integer numerics
-                    try:
-                        exploded_df[col] = pd.to_numeric(exploded_df[col])
-                    except (ValueError, TypeError):
-                        # Keep as object if not numeric
-                        pass
-        return exploded_df, array_row_groups
+    if not non_array_columns:
+        # Add a temporary column to allow grouping if all columns are arrays
+        non_array_columns = ["_temp_grouping_col"]
+        dataframe["_temp_grouping_col"] = range(len(dataframe))
+
+    # Preserve original index
+    if dataframe.index.name:
+        original_index_name = dataframe.index.name
+        dataframe = dataframe.reset_index()
+        non_array_columns.append(original_index_name)
+    else:
+        original_index_name = None
+        dataframe = dataframe.reset_index(names=["_original_index"])
+        non_array_columns.append("_original_index")
+
+    exploded_dfs = []
+    for col in array_columns:
+        # Explode each array column individually
+        exploded = dataframe[non_array_columns + [col]].explode(col)
+        exploded["_row_num"] = exploded.groupby(non_array_columns).cumcount()
+        exploded_dfs.append(exploded)
+
+    if not exploded_dfs:
+        return dataframe, {}
+
+    # Merge the exploded columns
+    merged_df = exploded_dfs[0]
+    for i in range(1, len(exploded_dfs)):
+        merged_df = pd.merge(
+            merged_df,
+            exploded_dfs[i],
+            on=non_array_columns + ["_row_num"],
+            how="outer",
+        )
+
+    # Restore original column order and sort
+    final_cols = dataframe.columns.tolist() + ["_row_num"]
+    merged_df = merged_df.sort_values(non_array_columns + ["_row_num"]).reset_index(
+        drop=True
+    )
+
+    # Create row groups
+    array_row_groups = {}
+    if "_original_index" in merged_df.columns:
+        grouping_col = "_original_index"
+    elif original_index_name:
+        grouping_col = original_index_name
     else:
-        return dataframe, array_row_groups
+        # Fallback if no clear grouping column is identified
+        grouping_col = non_array_columns[0]
+
+    for orig_idx, group in merged_df.groupby(grouping_col):
+        array_row_groups[str(orig_idx)] = group.index.tolist()
+
+    # Clean up temporary columns
+    if "_temp_grouping_col" in merged_df.columns:
+        merged_df = merged_df.drop(columns=["_temp_grouping_col"])
+        final_cols.remove("_temp_grouping_col")
+    if "_original_index" in merged_df.columns:
+        merged_df = merged_df.drop(columns=["_original_index"])
+        final_cols.remove("_original_index")
+    if original_index_name:
+        merged_df = merged_df.set_index(original_index_name)
+        final_cols.remove(original_index_name)
+
+    final_cols.remove("_row_num")
+    merged_df = merged_df[final_cols]
+
+    return merged_df, array_row_groups


 def _flatten_struct_columns(
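
The hunk above replaces a per-row Python loop with pandas' vectorized `explode`: each array column is exploded on its own, `groupby(...).cumcount()` assigns every element its position within the original row, and an outer merge on that position re-aligns the columns. A minimal sketch of the alignment trick, with made-up column names (`id`, `a`, `b`):

```python
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "a": [[10, 20], [30]], "b": [["x"], ["y", "z"]]})

pieces = []
for col in ["a", "b"]:
    part = df[["id", col]].explode(col)
    # Element position within each original row, used as the join key.
    part["_row_num"] = part.groupby("id").cumcount()
    pieces.append(part)

# Outer merge pads the shorter array with NaN at the missing positions.
merged = pieces[0].merge(pieces[1], on=["id", "_row_num"], how="outer")
print(merged.sort_values(["id", "_row_num"]))
#    id    a  _row_num    b
#     1   10         0    x
#     1   20         1  NaN
#     2   30         0    y
#     2  NaN         1    z
```

The `how="outer"` is what lets arrays of unequal length in the same row line up element by element instead of forming a cross product.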
@@ -295,7 +296,7 @@ def _flatten_struct_columns(

         # Create a new Series from the flattened array
         new_cols_to_add[new_col_name] = pd.Series(
-            flattened_fields[field_idx].to_pylist(),
+            flattened_fields[field_idx],
             dtype=pd.ArrowDtype(field.type),
             index=result_df.index,
         )
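
This hunk makes the same change for plain struct columns. The pattern the surrounding code appears to rely on (assumed here, since only the Series construction is visible in the diff) is pyarrow's `StructArray.flatten()`, which yields one child array per field with the parent's nulls propagated; each child can then back an Arrow-typed Series directly:

```python
import pandas as pd
import pyarrow as pa

struct_arr = pa.array(
    [{"lat": 1.5, "lon": 2.5}, None],
    type=pa.struct([("lat", pa.float64()), ("lon", pa.float64())]),
)

# flatten() returns one array per field; the null struct row becomes a
# null in every child, unlike .field(i), which ignores parent validity.
for i, child in enumerate(struct_arr.flatten()):
    field = struct_arr.type.field(i)
    s = pd.Series(child, dtype=pd.ArrowDtype(field.type), name=field.name)
    print(s)
```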