Skip to content

Commit 0f48f82

Browse files
committed
refactor: improve _flatten readability and table widget styles
1 parent 4d46e3c commit 0f48f82

File tree

1 file changed

+117
-37
lines changed

1 file changed

+117
-37
lines changed

bigframes/display/_flatten.py

Lines changed: 117 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Utilities for flattening nested data structures for display."""
15+
"""Utilities for flattening nested data structures for display.
16+
17+
This module provides functionality to flatten BigQuery STRUCT and ARRAY columns
18+
in a pandas DataFrame into a format suitable for display in a 2D table widget.
19+
It handles nested structures by:
20+
1. Expanding STRUCT fields into separate columns (e.g., "struct.field").
21+
2. Exploding ARRAY elements into multiple rows, replicating other columns.
22+
3. Generating metadata for grouping rows and handling continuation values.
23+
"""
1624

1725
from __future__ import annotations
1826

@@ -25,62 +33,75 @@
2533

2634
@dataclasses.dataclass(frozen=True)
2735
class FlattenResult:
28-
"""The result of flattening a DataFrame."""
36+
"""The result of flattening a DataFrame.
2937
30-
dataframe: pd.DataFrame
31-
"""The flattened DataFrame."""
38+
Attributes:
39+
dataframe: The flattened DataFrame.
40+
row_labels: A list of original row labels for each row in the flattened DataFrame.
41+
continuation_rows: A set of row indices that are continuation rows.
42+
cleared_on_continuation: A list of column names that should be cleared on continuation rows.
43+
nested_columns: A set of column names that were created from nested data.
44+
"""
3245

46+
dataframe: pd.DataFrame
3347
row_labels: list[str] | None
34-
"""A list of original row labels for each row in the flattened DataFrame."""
35-
3648
continuation_rows: set[int] | None
37-
"""A set of row indices that are continuation rows."""
38-
3949
cleared_on_continuation: list[str]
40-
"""A list of column names that should be cleared on continuation rows."""
41-
4250
nested_columns: set[str]
43-
"""A set of column names that were created from nested data."""
4451

4552

4653
@dataclasses.dataclass(frozen=True)
4754
class ColumnClassification:
48-
"""The result of classifying columns."""
55+
"""The result of classifying columns.
4956
50-
struct_columns: list[str]
51-
"""Columns that are STRUCTs."""
57+
Attributes:
58+
struct_columns: Columns that are STRUCTs.
59+
array_columns: Columns that are ARRAYs.
60+
array_of_struct_columns: Columns that are ARRAYs of STRUCTs.
61+
clear_on_continuation_cols: Columns that should be cleared on continuation rows.
62+
nested_originated_columns: Columns that were created from nested data.
63+
"""
5264

65+
struct_columns: list[str]
5366
array_columns: list[str]
54-
"""Columns that are ARRAYs."""
55-
5667
array_of_struct_columns: list[str]
57-
"""Columns that are ARRAYs of STRUCTs."""
58-
5968
clear_on_continuation_cols: list[str]
60-
"""Columns that should be cleared on continuation rows."""
61-
6269
nested_originated_columns: set[str]
63-
"""Columns that were created from nested data."""
6470

6571

6672
@dataclasses.dataclass(frozen=True)
6773
class ExplodeResult:
68-
"""The result of exploding array columns."""
74+
"""The result of exploding array columns.
6975
70-
dataframe: pd.DataFrame
71-
"""The exploded DataFrame."""
76+
Attributes:
77+
dataframe: The exploded DataFrame.
78+
row_labels: Labels for the rows.
79+
continuation_rows: Indices of continuation rows.
80+
"""
7281

82+
dataframe: pd.DataFrame
7383
row_labels: list[str]
74-
"""Labels for the rows."""
75-
7684
continuation_rows: set[int]
77-
"""Indices of continuation rows."""
7885

7986

8087
def flatten_nested_data(
8188
dataframe: pd.DataFrame,
8289
) -> FlattenResult:
83-
"""Flatten nested STRUCT and ARRAY columns for display."""
90+
"""Flatten nested STRUCT and ARRAY columns for display.
91+
92+
This function coordinates the flattening process:
93+
1. Classifies columns into STRUCT, ARRAY, ARRAY-of-STRUCT, and standard types.
94+
2. Flattens ARRAY-of-STRUCT columns into multiple ARRAY columns (one per struct field).
95+
This simplifies the subsequent explosion step.
96+
3. Flattens top-level STRUCT columns into separate columns.
97+
4. Explodes all ARRAY columns (original and those from step 2) into multiple rows.
98+
99+
Args:
100+
dataframe: The input DataFrame containing potential nested structures.
101+
102+
Returns:
103+
A FlattenResult containing the flattened DataFrame and metadata for display.
104+
"""
84105
if dataframe.empty:
85106
return FlattenResult(
86107
dataframe=dataframe.copy(),
@@ -93,10 +114,9 @@ def flatten_nested_data(
93114
result_df = dataframe.copy()
94115

95116
classification = _classify_columns(result_df)
96-
# Extract lists to allow modification
97-
# TODO(b/469966526): The modification of these lists in place by subsequent functions
98-
# (e.g. _flatten_array_of_struct_columns removing items from array_columns) suggests
99-
# that the data flow here could be cleaner, but keeping it as is for now.
117+
# Extract lists to allow modification by subsequent steps.
118+
# _flatten_array_of_struct_columns will modify array_columns to replace
119+
# the original array-of-struct column with the new flattened array columns.
100120
struct_columns = classification.struct_columns
101121
array_columns = classification.array_columns
102122
array_of_struct_columns = classification.array_of_struct_columns
@@ -134,7 +154,17 @@ def flatten_nested_data(
134154
def _classify_columns(
135155
dataframe: pd.DataFrame,
136156
) -> ColumnClassification:
137-
"""Identify all STRUCT and ARRAY columns."""
157+
"""Identify all STRUCT and ARRAY columns in the DataFrame.
158+
159+
It inspects the PyArrow dtype of each column to determine if it is a
160+
STRUCT, LIST (Array), or LIST of STRUCTs.
161+
162+
Args:
163+
dataframe: The DataFrame to inspect.
164+
165+
Returns:
166+
A ColumnClassification object containing lists of column names for each category.
167+
"""
138168
initial_columns = list(dataframe.columns)
139169
struct_columns: list[str] = []
140170
array_columns: list[str] = []
@@ -176,7 +206,21 @@ def _flatten_array_of_struct_columns(
176206
array_columns: list[str],
177207
nested_originated_columns: set[str],
178208
) -> tuple[pd.DataFrame, list[str]]:
179-
"""Flatten ARRAY of STRUCT columns into separate array columns for each field."""
209+
"""Flatten ARRAY of STRUCT columns into separate ARRAY columns for each field.
210+
211+
For example, an ARRAY<STRUCT<a INT64, b STRING>> column named 'items' will be
212+
converted into two ARRAY columns: 'items.a' (ARRAY<INT64>) and 'items.b' (ARRAY<STRING>).
213+
This allows us to treat them as standard ARRAY columns for the subsequent explosion step.
214+
215+
Args:
216+
dataframe: The DataFrame to process.
217+
array_of_struct_columns: List of column names that are ARRAYs of STRUCTs.
218+
array_columns: The main list of ARRAY columns to be updated.
219+
nested_originated_columns: Set of columns tracked as originating from nested data.
220+
221+
Returns:
222+
A tuple containing the modified DataFrame and the updated list of array columns.
223+
"""
180224
result_df = dataframe.copy()
181225
for col_name in array_of_struct_columns:
182226
col_data = result_df[col_name]
@@ -233,7 +277,26 @@ def _flatten_array_of_struct_columns(
233277
def _explode_array_columns(
234278
dataframe: pd.DataFrame, array_columns: list[str]
235279
) -> ExplodeResult:
236-
"""Explode array columns into new rows."""
280+
"""Explode array columns into new rows.
281+
282+
This function performs the "flattening" of 1D arrays by exploding them.
283+
It handles multiple array columns by ensuring they are exploded in sync
284+
relative to the other columns.
285+
286+
Design details:
287+
- We group by all non-array columns to maintain context.
288+
- `_row_num` is used to track the index within the exploded array, effectively
289+
synchronizing multiple arrays if they belong to the same row.
290+
- Continuation rows (index > 0 in the explosion) are tracked so we can clear
291+
repeated values in the display.
292+
293+
Args:
294+
dataframe: The DataFrame to explode.
295+
array_columns: List of array columns to explode.
296+
297+
Returns:
298+
An ExplodeResult containing the new DataFrame and row metadata.
299+
"""
237300
if not array_columns:
238301
return ExplodeResult(dataframe, [], set())
239302

@@ -243,7 +306,8 @@ def _explode_array_columns(
243306
non_array_columns = work_df.columns.drop(array_columns).tolist()
244307
if not non_array_columns:
245308
work_df = work_df.copy() # Avoid modifying input
246-
# Add a temporary column to allow grouping if all columns are arrays
309+
# Add a temporary column to allow grouping if all columns are arrays.
310+
# This ensures we can still group by "original row" even if there are no scalar columns.
247311
non_array_columns = ["_temp_grouping_col"]
248312
work_df["_temp_grouping_col"] = range(len(work_df))
249313

@@ -278,6 +342,7 @@ def _explode_array_columns(
278342
# Re-cast to arrow dtype if possible
279343
exploded[col] = exploded[col].astype(target_dtype)
280344

345+
# Track position in the array for alignment
281346
exploded["_row_num"] = exploded.groupby(non_array_columns).cumcount()
282347
exploded_dfs.append(exploded)
283348

@@ -322,7 +387,22 @@ def _flatten_struct_columns(
322387
clear_on_continuation_cols: list[str],
323388
nested_originated_columns: set[str],
324389
) -> tuple[pd.DataFrame, list[str]]:
325-
"""Flatten regular STRUCT columns."""
390+
"""Flatten regular STRUCT columns into separate columns.
391+
392+
A STRUCT column 'user' with fields 'name' and 'age' becomes 'user.name'
393+
and 'user.age'.
394+
395+
Args:
396+
dataframe: The DataFrame to process.
397+
struct_columns: List of STRUCT columns to flatten.
398+
clear_on_continuation_cols: List of columns to clear on continuation,
399+
which will be updated with the new flattened columns.
400+
nested_originated_columns: Set of columns tracked as originating from nested data.
401+
402+
Returns:
403+
A tuple containing the modified DataFrame and the updated list of
404+
columns to clear on continuation.
405+
"""
326406
result_df = dataframe.copy()
327407
for col_name in struct_columns:
328408
col_data = result_df[col_name]

0 commit comments

Comments
 (0)