Skip to content

Commit 41df7b3

Browse files
committed
Fix: Improve performance of nested data flattening
1 parent ce59668 commit 41df7b3

File tree

1 file changed

+75
-74
lines changed

1 file changed

+75
-74
lines changed

bigframes/display/_flatten.py

Lines changed: 75 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -159,14 +159,14 @@ def _flatten_array_of_struct_columns(
159159
nested_originated_columns.add(new_col_name)
160160
new_array_col_names.append(new_col_name)
161161

162-
# Reconstruct ListArray for this field
163-
# Use mask=arrow_array.is_null() to preserve nulls from the original list
162+
# Reconstruct ListArray for this field. This transforms the
163+
# array<struct<f1, f2>> into separate array<f1> and array<f2> columns.
164164
new_list_array = pa.ListArray.from_arrays(
165165
offsets, flattened_fields[field_idx], mask=arrow_array.is_null()
166166
)
167167

168168
new_cols_to_add[new_col_name] = pd.Series(
169-
new_list_array.to_pylist(),
169+
new_list_array,
170170
dtype=pd.ArrowDtype(pa.list_(field.type)),
171171
index=result_df.index,
172172
)
@@ -194,78 +194,79 @@ def _explode_array_columns(
194194
dataframe: pd.DataFrame, array_columns: list[str]
195195
) -> tuple[pd.DataFrame, dict[str, list[int]]]:
196196
"""Explode array columns into new rows."""
197-
exploded_rows = []
198-
array_row_groups: dict[str, list[int]] = {}
197+
if not array_columns:
198+
return dataframe, {}
199+
199200
non_array_columns = dataframe.columns.drop(array_columns).tolist()
200-
non_array_df = dataframe[non_array_columns]
201-
202-
for orig_idx in dataframe.index:
203-
non_array_data = non_array_df.loc[orig_idx].to_dict()
204-
array_values = {}
205-
max_len_in_row = 0
206-
non_na_array_found = False
207-
208-
for col_name in array_columns:
209-
val = dataframe.loc[orig_idx, col_name]
210-
if val is not None and not (
211-
isinstance(val, list) and len(val) == 1 and pd.isna(val[0])
212-
):
213-
array_values[col_name] = list(val)
214-
max_len_in_row = max(max_len_in_row, len(val))
215-
non_na_array_found = True
216-
else:
217-
array_values[col_name] = []
218-
219-
if not non_na_array_found:
220-
new_row = non_array_data.copy()
221-
for col_name in array_columns:
222-
new_row[f"{col_name}"] = pd.NA
223-
exploded_rows.append(new_row)
224-
orig_key = str(orig_idx)
225-
if orig_key not in array_row_groups:
226-
array_row_groups[orig_key] = []
227-
array_row_groups[orig_key].append(len(exploded_rows) - 1)
228-
continue
229-
230-
# Create one row per array element, up to max_len_in_row
231-
for array_idx in range(max_len_in_row):
232-
new_row = non_array_data.copy()
233-
234-
# Add the specific array element for this index
235-
for col_name in array_columns:
236-
if array_idx < len(array_values.get(col_name, [])):
237-
new_row[f"{col_name}"] = array_values[col_name][array_idx]
238-
else:
239-
new_row[f"{col_name}"] = pd.NA
240-
241-
exploded_rows.append(new_row)
242-
243-
# Track which rows belong to which original row
244-
orig_key = str(orig_idx)
245-
if orig_key not in array_row_groups:
246-
array_row_groups[orig_key] = []
247-
array_row_groups[orig_key].append(len(exploded_rows) - 1)
248-
249-
if exploded_rows:
250-
# Reconstruct the DataFrame to maintain original column order
251-
exploded_df = pd.DataFrame(exploded_rows)[dataframe.columns]
252-
for col in exploded_df.columns:
253-
# After explosion, object columns that are all-numeric (except for NAs)
254-
# should be converted to a numeric dtype for proper alignment.
255-
if exploded_df[col].dtype == "object":
256-
try:
257-
# Use nullable integer type to preserve integers
258-
exploded_df[col] = exploded_df[col].astype(pd.Int64Dtype())
259-
except (ValueError, TypeError):
260-
# Fallback for non-integer numerics
261-
try:
262-
exploded_df[col] = pd.to_numeric(exploded_df[col])
263-
except (ValueError, TypeError):
264-
# Keep as object if not numeric
265-
pass
266-
return exploded_df, array_row_groups
201+
if not non_array_columns:
202+
# Add a temporary column to allow grouping if all columns are arrays
203+
non_array_columns = ["_temp_grouping_col"]
204+
dataframe["_temp_grouping_col"] = range(len(dataframe))
205+
206+
# Preserve original index
207+
if dataframe.index.name:
208+
original_index_name = dataframe.index.name
209+
dataframe = dataframe.reset_index()
210+
non_array_columns.append(original_index_name)
211+
else:
212+
original_index_name = None
213+
dataframe = dataframe.reset_index(names=["_original_index"])
214+
non_array_columns.append("_original_index")
215+
216+
exploded_dfs = []
217+
for col in array_columns:
218+
# Explode each array column individually
219+
exploded = dataframe[non_array_columns + [col]].explode(col)
220+
exploded["_row_num"] = exploded.groupby(non_array_columns).cumcount()
221+
exploded_dfs.append(exploded)
222+
223+
if not exploded_dfs:
224+
return dataframe, {}
225+
226+
# Merge the exploded columns
227+
merged_df = exploded_dfs[0]
228+
for i in range(1, len(exploded_dfs)):
229+
merged_df = pd.merge(
230+
merged_df,
231+
exploded_dfs[i],
232+
on=non_array_columns + ["_row_num"],
233+
how="outer",
234+
)
235+
236+
# Restore original column order and sort
237+
final_cols = dataframe.columns.tolist() + ["_row_num"]
238+
merged_df = merged_df.sort_values(non_array_columns + ["_row_num"]).reset_index(
239+
drop=True
240+
)
241+
242+
# Create row groups
243+
array_row_groups = {}
244+
if "_original_index" in merged_df.columns:
245+
grouping_col = "_original_index"
246+
elif original_index_name:
247+
grouping_col = original_index_name
267248
else:
268-
return dataframe, array_row_groups
249+
# Fallback if no clear grouping column is identified
250+
grouping_col = non_array_columns[0]
251+
252+
for orig_idx, group in merged_df.groupby(grouping_col):
253+
array_row_groups[str(orig_idx)] = group.index.tolist()
254+
255+
# Clean up temporary columns
256+
if "_temp_grouping_col" in merged_df.columns:
257+
merged_df = merged_df.drop(columns=["_temp_grouping_col"])
258+
final_cols.remove("_temp_grouping_col")
259+
if "_original_index" in merged_df.columns:
260+
merged_df = merged_df.drop(columns=["_original_index"])
261+
final_cols.remove("_original_index")
262+
if original_index_name:
263+
merged_df = merged_df.set_index(original_index_name)
264+
final_cols.remove(original_index_name)
265+
266+
final_cols.remove("_row_num")
267+
merged_df = merged_df[final_cols]
268+
269+
return merged_df, array_row_groups
269270

270271

271272
def _flatten_struct_columns(
@@ -295,7 +296,7 @@ def _flatten_struct_columns(
295296

296297
# Create a new Series from the flattened array
297298
new_cols_to_add[new_col_name] = pd.Series(
298-
flattened_fields[field_idx].to_pylist(),
299+
flattened_fields[field_idx],
299300
dtype=pd.ArrowDtype(field.type),
300301
index=result_df.index,
301302
)

0 commit comments

Comments (0)