Skip to content

Commit e364674

Browse files
committed
Fix: Correct bug in nested data flattening
1 parent 41df7b3 commit e364674

File tree

1 file changed

+24
-32
lines changed

1 file changed

+24
-32
lines changed

bigframes/display/_flatten.py

Lines changed: 24 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -197,30 +197,35 @@ def _explode_array_columns(
197197
if not array_columns:
198198
return dataframe, {}
199199

200-
non_array_columns = dataframe.columns.drop(array_columns).tolist()
200+
original_cols = dataframe.columns.tolist()
201+
work_df = dataframe
202+
203+
non_array_columns = work_df.columns.drop(array_columns).tolist()
201204
if not non_array_columns:
205+
work_df = work_df.copy() # Avoid modifying input
202206
# Add a temporary column to allow grouping if all columns are arrays
203207
non_array_columns = ["_temp_grouping_col"]
204-
dataframe["_temp_grouping_col"] = range(len(dataframe))
208+
work_df["_temp_grouping_col"] = range(len(work_df))
205209

206210
# Preserve original index
207-
if dataframe.index.name:
208-
original_index_name = dataframe.index.name
209-
dataframe = dataframe.reset_index()
211+
if work_df.index.name:
212+
original_index_name = work_df.index.name
213+
work_df = work_df.reset_index()
210214
non_array_columns.append(original_index_name)
211215
else:
212216
original_index_name = None
213-
dataframe = dataframe.reset_index(names=["_original_index"])
217+
work_df = work_df.reset_index(names=["_original_index"])
214218
non_array_columns.append("_original_index")
215219

216220
exploded_dfs = []
217221
for col in array_columns:
218222
# Explode each array column individually
219-
exploded = dataframe[non_array_columns + [col]].explode(col)
223+
exploded = work_df[non_array_columns + [col]].explode(col)
220224
exploded["_row_num"] = exploded.groupby(non_array_columns).cumcount()
221225
exploded_dfs.append(exploded)
222226

223227
if not exploded_dfs:
228+
# This should not be reached if array_columns is not empty
224229
return dataframe, {}
225230

226231
# Merge the exploded columns
@@ -234,39 +239,26 @@ def _explode_array_columns(
234239
)
235240

236241
# Restore original column order and sort
237-
final_cols = dataframe.columns.tolist() + ["_row_num"]
238242
merged_df = merged_df.sort_values(non_array_columns + ["_row_num"]).reset_index(
239243
drop=True
240244
)
241245

242246
# Create row groups
243247
array_row_groups = {}
244-
if "_original_index" in merged_df.columns:
245-
grouping_col = "_original_index"
246-
elif original_index_name:
247-
grouping_col = original_index_name
248-
else:
249-
# Fallback if no clear grouping column is identified
250-
grouping_col = non_array_columns[0]
251-
252-
for orig_idx, group in merged_df.groupby(grouping_col):
253-
array_row_groups[str(orig_idx)] = group.index.tolist()
254-
255-
# Clean up temporary columns
256-
if "_temp_grouping_col" in merged_df.columns:
257-
merged_df = merged_df.drop(columns=["_temp_grouping_col"])
258-
final_cols.remove("_temp_grouping_col")
259-
if "_original_index" in merged_df.columns:
260-
merged_df = merged_df.drop(columns=["_original_index"])
261-
final_cols.remove("_original_index")
262-
if original_index_name:
263-
merged_df = merged_df.set_index(original_index_name)
264-
final_cols.remove(original_index_name)
248+
grouping_col_name = (
249+
"_original_index" if original_index_name is None else original_index_name
250+
)
251+
if grouping_col_name in merged_df.columns:
252+
for orig_idx, group in merged_df.groupby(grouping_col_name):
253+
array_row_groups[str(orig_idx)] = group.index.tolist()
254+
255+
# Restore original columns
256+
result_df = merged_df[original_cols]
265257

266-
final_cols.remove("_row_num")
267-
merged_df = merged_df[final_cols]
258+
if original_index_name:
259+
result_df = result_df.set_index(original_index_name)
268260

269-
return merged_df, array_row_groups
261+
return result_df, array_row_groups
270262

271263

272264
def _flatten_struct_columns(

0 commit comments

Comments
 (0)