@@ -155,38 +155,37 @@ def _classify_columns(
155155 Returns:
156156 A ColumnClassification object containing lists of column names for each category.
157157 """
158- initial_columns = list (dataframe .columns )
159- struct_columns : list [str ] = []
160- array_columns : list [str ] = []
161- array_of_struct_columns : list [str ] = []
162- clear_on_continuation_cols : list [str ] = []
163- nested_originated_columns : set [str ] = set ()
164-
165- for col_name_raw , col_data in dataframe .items ():
166- col_name = str (col_name_raw )
167- dtype = col_data .dtype
168- if isinstance (dtype , pd .ArrowDtype ):
169- pa_type = dtype .pyarrow_dtype
170- if pa .types .is_struct (pa_type ):
171- struct_columns .append (col_name )
172- nested_originated_columns .add (col_name )
173- elif pa .types .is_list (pa_type ):
174- array_columns .append (col_name )
175- nested_originated_columns .add (col_name )
176- if hasattr (pa_type , "value_type" ) and (
177- pa .types .is_struct (pa_type .value_type )
178- ):
179- array_of_struct_columns .append (col_name )
180- else :
181- clear_on_continuation_cols .append (col_name )
182- elif col_name in initial_columns :
183- clear_on_continuation_cols .append (col_name )
158+ # Maps column names to their structural category to simplify list building.
159+ categories : dict [str , str ] = {}
160+
161+ for col , dtype in dataframe .dtypes .items ():
162+ col_name = str (col )
163+ pa_type = getattr (dtype , "pyarrow_dtype" , None )
164+
165+ if not pa_type :
166+ categories [col_name ] = "clear"
167+ elif pa .types .is_struct (pa_type ):
168+ categories [col_name ] = "struct"
169+ elif pa .types .is_list (pa_type ):
170+ is_struct_array = pa .types .is_struct (pa_type .value_type )
171+ categories [col_name ] = "array_of_struct" if is_struct_array else "array"
172+ else :
173+ categories [col_name ] = "clear"
174+
184175 return ColumnClassification (
185- struct_columns = struct_columns ,
186- array_columns = array_columns ,
187- array_of_struct_columns = array_of_struct_columns ,
188- clear_on_continuation_cols = clear_on_continuation_cols ,
189- nested_originated_columns = nested_originated_columns ,
176+ struct_columns = [c for c , cat in categories .items () if cat == "struct" ],
177+ array_columns = [
178+ c for c , cat in categories .items () if cat in ("array" , "array_of_struct" )
179+ ],
180+ array_of_struct_columns = [
181+ c for c , cat in categories .items () if cat == "array_of_struct"
182+ ],
183+ clear_on_continuation_cols = [
184+ c for c , cat in categories .items () if cat == "clear"
185+ ],
186+ nested_originated_columns = {
187+ c for c , cat in categories .items () if cat != "clear"
188+ },
190189 )
191190
192191
@@ -198,10 +197,6 @@ def _flatten_array_of_struct_columns(
198197) -> tuple [pd .DataFrame , list [str ]]:
199198 """Flatten ARRAY of STRUCT columns into separate ARRAY columns for each field.
200199
201- For example, an ARRAY<STRUCT<a INT64, b STRING>> column named 'items' will be
202- converted into two ARRAY columns: 'items.a' (ARRAY<INT64>) and 'items.b' (ARRAY<STRING>).
203- This allows us to treat them as standard ARRAY columns for the subsequent explosion step.
204-
205200 Args:
206201 dataframe: The DataFrame to process.
207202 array_of_struct_columns: List of column names that are ARRAYs of STRUCTs.
@@ -214,56 +209,86 @@ def _flatten_array_of_struct_columns(
214209 result_df = dataframe .copy ()
215210 for col_name in array_of_struct_columns :
216211 col_data = result_df [col_name ]
217- pa_type = cast (pd .ArrowDtype , col_data .dtype ).pyarrow_dtype
218- struct_type = pa_type .value_type
219-
220- # Use PyArrow to reshape the list<struct> into multiple list<field> arrays
212+ # Ensure we have a PyArrow array (pa.array handles pandas Series conversion)
221213 arrow_array = pa .array (col_data )
222- offsets = arrow_array .offsets
223- values = arrow_array .values # StructArray
224- flattened_fields = values .flatten () # List[Array]
225-
226- new_cols_to_add = {}
227- new_array_col_names = []
228214
229- # Create new columns for each struct field
230- for field_idx in range (struct_type .num_fields ):
231- field = struct_type .field (field_idx )
232- new_col_name = f"{ col_name } .{ field .name } "
233- nested_originated_columns .add (new_col_name )
234- new_array_col_names .append (new_col_name )
215+ # Transpose List<Struct<...>> to {field: List<field_type>}
216+ new_arrays = _transpose_list_of_structs (arrow_array )
235217
236- # Reconstruct ListArray for this field. This transforms the
237- # array<struct<f1, f2>> into separate array<f1> and array<f2> columns.
238- new_list_array = pa .ListArray .from_arrays (
239- offsets , flattened_fields [field_idx ], mask = arrow_array .is_null ()
240- )
241-
242- new_cols_to_add [new_col_name ] = pd .Series (
243- new_list_array ,
244- dtype = pd .ArrowDtype (pa .list_ (field .type )),
245- index = result_df .index ,
246- )
218+ new_cols_df = pd .DataFrame (
219+ {
220+ f"{ col_name } .{ field_name } " : pd .Series (
221+ arr , dtype = pd .ArrowDtype (arr .type ), index = result_df .index
222+ )
223+ for field_name , arr in new_arrays .items ()
224+ }
225+ )
247226
248- col_idx = result_df .columns .to_list ().index (col_name )
249- new_cols_df = pd .DataFrame (new_cols_to_add , index = result_df .index )
227+ # Track the new columns
228+ for new_col in new_cols_df .columns :
229+ nested_originated_columns .add (new_col )
250230
251- result_df = pd .concat (
252- [
253- result_df .iloc [:, :col_idx ],
254- new_cols_df ,
255- result_df .iloc [:, col_idx + 1 :],
256- ],
257- axis = 1 ,
258- )
231+ # Update the DataFrame
232+ result_df = _replace_column_in_df (result_df , col_name , new_cols_df )
259233
260234 # Update array_columns list
261235 array_columns .remove (col_name )
262- # Add the new array columns
263- array_columns . extend ( new_array_col_names )
236+ array_columns . extend ( new_cols_df . columns . tolist ())
237+
264238 return result_df , array_columns
265239
266240
def _transpose_list_of_structs(arrow_array: pa.ListArray) -> dict[str, pa.ListArray]:
    """Split a ListArray of structs into one ListArray per struct field.

    Args:
        arrow_array: A PyArrow ListArray whose value type is a struct.

    Returns:
        A dict mapping each struct field name to a new ListArray that carries
        only that field's values, rebuilt with the original list offsets and
        null mask (so List<Struct<A, B>> becomes List<A> and List<B>).
    """
    struct_type = arrow_array.type.value_type
    list_offsets = arrow_array.offsets
    null_mask = arrow_array.is_null()
    # `arrow_array.values` is the underlying StructArray; flattening it yields
    # one flat child array per struct field, stripping the struct layer.
    field_arrays = arrow_array.values.flatten()

    return {
        struct_type.field(idx).name: pa.ListArray.from_arrays(
            list_offsets, field_arrays[idx], mask=null_mask
        )
        for idx in range(struct_type.num_fields)
    }
266+
267+
268+ def _replace_column_in_df (
269+ dataframe : pd .DataFrame , col_name : str , new_cols : pd .DataFrame
270+ ) -> pd .DataFrame :
271+ """Replaces a column in a DataFrame with a set of new columns at the same position.
272+
273+ Args:
274+ dataframe: The original DataFrame.
275+ col_name: The name of the column to replace.
276+ new_cols: A DataFrame containing the new columns to insert.
277+
278+ Returns:
279+ A new DataFrame with the substitution made.
280+ """
281+ col_idx = dataframe .columns .to_list ().index (col_name )
282+ return pd .concat (
283+ [
284+ dataframe .iloc [:, :col_idx ],
285+ new_cols ,
286+ dataframe .iloc [:, col_idx + 1 :],
287+ ],
288+ axis = 1 ,
289+ )
290+
291+
267292def _explode_array_columns (
268293 dataframe : pd .DataFrame , array_columns : list [str ]
269294) -> ExplodeResult :
0 commit comments