diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index ee35e237b8983..d256e786e00cb 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -380,6 +380,11 @@ "<arg_name> index out of range, got '<index>'." ] }, + "INVALID_ARROW_BATCH_ZIP": { + "message": [ + "Cannot zip Arrow batches/arrays: <reason>." + ] + }, "INVALID_ARROW_UDTF_RETURN_TYPE": { "message": [ "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <return_type> with value: <value>." diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 2288677763d3f..2332bfe49ce23 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -69,7 +69,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager -from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer +from pyspark.sql.conversion import PandasBatchTransformer from pyspark.sql.pandas.types import ( to_arrow_schema, to_arrow_type, @@ -630,15 +630,15 @@ def createDataFrame( safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] - ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true", False) - _table = pa.Table.from_batches( [ - ser._create_batch( + PandasBatchTransformer.to_arrow( [ (c, at, st) for (_, c), at, st in zip(data.items(), arrow_types, spark_types) - ] + ], + timezone=cast(str, timezone), + safecheck=safecheck == "true", ) ] ) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index fdcb29b54412e..dbb6eda069a38 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -18,13 +18,26 @@ import array import datetime import decimal -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Optional, + Sequence, + Tuple, + Union, + cast, + overload, +) -from pyspark.errors import PySparkValueError +from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError from pyspark.sql.pandas.types import ( _dedup_names, _deduplicate_field_names, _create_converter_to_pandas, + _create_converter_from_pandas, + from_arrow_type, to_arrow_schema, ) from pyspark.sql.pandas.utils import require_minimum_pyarrow_version @@ -58,8 +71,13 @@ class ArrowBatchTransformer: """ - Pure functions that transform RecordBatch -> RecordBatch. + Pure functions that transform Arrow data structures (Arrays, RecordBatches). They should have no side effects (no I/O, no writing to streams). + + This class provides utility methods for Arrow batch transformations used throughout + PySpark's Arrow UDF implementation. All methods are static and handle common patterns + like struct wrapping/unwrapping, schema conversions, and creating RecordBatches from Arrays. + """ @staticmethod @@ -67,10 +85,14 @@ def flatten_struct(batch: "pa.RecordBatch", column_index: int = 0) -> "pa.Record """ Flatten a struct column at given index into a RecordBatch.
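A minimal sketch of the expected behavior (the field names and values here are only an illustration, not taken from the patch):

    import pyarrow as pa

    struct = pa.StructArray.from_arrays(
        [pa.array([1, 2]), pa.array(["x", "y"])], names=["id", "name"]
    )
    batch = pa.RecordBatch.from_arrays([struct], ["_0"])
    flat = ArrowBatchTransformer.flatten_struct(batch)
    # flat should expose the struct's fields "id" and "name" as top-level columns
    # instead of the single struct column "_0"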
- Used by: - - ArrowStreamUDFSerializer.load_stream - - SQL_GROUPED_MAP_ARROW_UDF mapper - - SQL_GROUPED_MAP_ARROW_ITER_UDF mapper + Used by + ------- + - ArrowStreamGroupSerializer + - SQL_ARROW_UDTF mapper + - SQL_MAP_ARROW_ITER_UDF mapper + - SQL_GROUPED_MAP_ARROW_UDF mapper + - SQL_GROUPED_MAP_ARROW_ITER_UDF mapper + - SQL_COGROUPED_MAP_ARROW_UDF mapper """ import pyarrow as pa @@ -82,7 +104,14 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": """ Wrap a RecordBatch's columns into a single struct column. - Used by: ArrowStreamUDFSerializer.dump_stream + Used by + ------- + - wrap_grouped_map_arrow_udf + - wrap_grouped_map_arrow_iter_udf + - wrap_cogrouped_map_arrow_udf + - wrap_arrow_batch_iter_udf + - SQL_ARROW_UDTF mapper + - TransformWithStateInPySparkRowSerializer.dump_stream """ import pyarrow as pa @@ -94,6 +123,509 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch": struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema))) return pa.RecordBatch.from_arrays([struct], ["_0"]) + @classmethod + def enforce_schema( + cls, + batch: "pa.RecordBatch", + return_type: "StructType", + timezone: str = "UTC", + prefer_large_var_types: bool = False, + safecheck: bool = True, + ) -> "pa.RecordBatch": + """ + Enforce target schema on a RecordBatch by reordering columns and coercing types. + + This method: + 1. Reorders columns to match the target schema field order by name + 2. Casts column types to match target schema types + + Parameters + ---------- + batch : pa.RecordBatch + Input RecordBatch to transform + return_type : pyspark.sql.types.StructType + Target Spark schema to enforce. + timezone : str, default "UTC" + Timezone for timestamp type conversion. + prefer_large_var_types : bool, default False + If True, use large variable types (large_string, large_binary) in Arrow. + safecheck : bool, default True + If True, use safe casting (fails on overflow/truncation). + + Returns + ------- + pa.RecordBatch + RecordBatch with columns reordered and types coerced to match target schema + + Used by + ------- + - wrap_grouped_map_arrow_udf + - wrap_grouped_map_arrow_iter_udf + - wrap_cogrouped_map_arrow_udf + - SQL_ARROW_UDTF mapper + """ + import pyarrow as pa + from pyspark.sql.pandas.types import to_arrow_schema + + # Handle empty batch case (no columns) or empty struct case + if batch.num_columns == 0: + return batch + + if len(return_type) == 0: + return batch + + # Convert Spark StructType to PyArrow schema + arrow_schema = to_arrow_schema( + return_type, timezone=timezone, prefers_large_types=prefer_large_var_types + ) + + target_field_names = [field.name for field in arrow_schema] + + # Reorder columns by name and coerce types + coerced_arrays = [] + for field in arrow_schema: + arr = batch.column(field.name) + if arr.type == field.type: + coerced_arrays.append(arr) + else: + try: + coerced_arrays.append(arr.cast(target_type=field.type, safe=safecheck)) + except (pa.ArrowInvalid, pa.ArrowTypeError): + raise PySparkRuntimeError( + errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", + messageParameters={ + "expected": str(field.type), + "actual": str(arr.type), + }, + ) + + return pa.RecordBatch.from_arrays(coerced_arrays, names=target_field_names) + + @classmethod + def concat_batches(cls, batches: List["pa.RecordBatch"]) -> "pa.RecordBatch": + """ + Concatenate multiple RecordBatches into a single RecordBatch. + + This method handles both modern and legacy PyArrow versions. 
+ + Parameters + ---------- + batches : List[pa.RecordBatch] + List of RecordBatches with the same schema + + Returns + ------- + pa.RecordBatch + Single RecordBatch containing all rows from input batches + + Used by + ------- + - SQL_GROUPED_AGG_ARROW_UDF mapper + - SQL_WINDOW_AGG_ARROW_UDF mapper + + Examples + -------- + >>> import pyarrow as pa + >>> batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ['a']) + >>> batch2 = pa.RecordBatch.from_arrays([pa.array([3, 4])], ['a']) + >>> result = ArrowBatchTransformer.concat_batches([batch1, batch2]) + >>> result.to_pydict() + {'a': [1, 2, 3, 4]} + """ + import pyarrow as pa + + if not batches: + raise PySparkValueError( + errorClass="INVALID_ARROW_BATCH_CONCAT", + messageParameters={"reason": "Cannot concatenate empty list of batches"}, + ) + + # Assert all batches have the same schema + first_schema = batches[0].schema + for i, batch in enumerate(batches[1:], start=1): + if batch.schema != first_schema: + raise PySparkValueError( + errorClass="INVALID_ARROW_BATCH_CONCAT", + messageParameters={ + "reason": ( + f"All batches must have the same schema. " + f"Batch 0 has schema {first_schema}, but batch {i} has schema {batch.schema}." + ) + }, + ) + + if hasattr(pa, "concat_batches"): + return pa.concat_batches(batches) + else: + # pyarrow.concat_batches not supported in old versions + return pa.RecordBatch.from_struct_array( + pa.concat_arrays([b.to_struct_array() for b in batches]) + ) + + @classmethod + def zip_batches( + cls, + items: Union[ + List["pa.RecordBatch"], + List["pa.Array"], + List[Tuple["pa.Array", "pa.DataType"]], + ], + safecheck: bool = True, + ) -> "pa.RecordBatch": + """ + Zip multiple RecordBatches or Arrays horizontally by combining their columns. + + This is different from concat_batches which concatenates rows vertically. + This method combines columns from multiple batches/arrays into a single batch, + useful when multiple UDFs each produce a RecordBatch or when combining arrays. + + Parameters + ---------- + items : List[pa.RecordBatch], List[pa.Array], or List[Tuple[pa.Array, pa.DataType]] + - List of RecordBatches to zip (must have same number of rows) + - List of Arrays to combine directly + - List of (array, type) tuples for type casting (always attempts cast if types don't match) + safecheck : bool, default True + If True, use safe casting (fails on overflow/truncation) (only used when items are tuples). + + Returns + ------- + pa.RecordBatch + Single RecordBatch with all columns from input batches/arrays + + Used by + ------- + - SQL_GROUPED_AGG_ARROW_UDF mapper + - SQL_WINDOW_AGG_ARROW_UDF mapper + - wrap_scalar_arrow_udf + - wrap_grouped_agg_arrow_udf + - ArrowBatchUDFSerializer.dump_stream + + Examples + -------- + >>> import pyarrow as pa + >>> batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ['a']) + >>> batch2 = pa.RecordBatch.from_arrays([pa.array([3, 4])], ['b']) + >>> result = ArrowBatchTransformer.zip_batches([batch1, batch2]) + >>> result.to_pydict() + {'_0': [1, 2], '_1': [3, 4]} + >>> # Can also zip arrays directly + >>> result = ArrowBatchTransformer.zip_batches([pa.array([1, 2]), pa.array([3, 4])]) + >>> result.to_pydict() + {'_0': [1, 2], '_1': [3, 4]} + >>> # Can also zip with type casting + >>> result = ArrowBatchTransformer.zip_batches( + ... [(pa.array([1, 2]), pa.int64()), (pa.array([3, 4]), pa.int64())] + ... 
) + """ + import pyarrow as pa + + if not items: + raise PySparkValueError( + errorClass="INVALID_ARROW_BATCH_ZIP", + messageParameters={"reason": "Cannot zip empty list"}, + ) + + # Check if items are RecordBatches, Arrays, or (array, type) tuples + first_item = items[0] + + if isinstance(first_item, pa.RecordBatch): + # Handle RecordBatches + batches = cast(List["pa.RecordBatch"], items) + if len(batches) == 1: + return batches[0] + + # Assert all batches have the same number of rows + num_rows = batches[0].num_rows + for i, batch in enumerate(batches[1:], start=1): + assert batch.num_rows == num_rows, ( + f"All batches must have the same number of rows. " + f"Batch 0 has {num_rows} rows, but batch {i} has {batch.num_rows} rows." + ) + + # Combine all columns from all batches + all_columns = [] + for batch in batches: + all_columns.extend(batch.columns) + elif isinstance(first_item, tuple) and len(first_item) == 2: + # Handle (array, type) tuples with type casting (always attempt cast if types don't match) + all_columns = [ + cls._cast_array( + array, + arrow_type, + safecheck=safecheck, + error_message=( + "Arrow UDFs require the return type to match the expected Arrow type. " + f"Expected: {arrow_type}, but got: {array.type}." + ), + ) + for array, arrow_type in items + ] + else: + # Handle Arrays directly + all_columns = list(items) + + # Create RecordBatch from columns + return pa.RecordBatch.from_arrays(all_columns, ["_%d" % i for i in range(len(all_columns))]) + + @classmethod + def to_pandas( + cls, + batch: Union["pa.RecordBatch", "pa.Table"], + timezone: str, + schema: Optional["StructType"] = None, + struct_in_pandas: str = "dict", + ndarray_as_list: bool = False, + df_for_struct: bool = False, + ) -> List[Union["pd.Series", "pd.DataFrame"]]: + """ + Convert a RecordBatch or Table to a list of pandas Series. + + Parameters + ---------- + batch : pa.RecordBatch or pa.Table + The Arrow RecordBatch or Table to convert. + timezone : str + Timezone for timestamp conversion. + schema : StructType, optional + Spark schema for type conversion. If None, types are inferred from Arrow. + struct_in_pandas : str + How to represent struct in pandas ("dict", "row", etc.) + ndarray_as_list : bool + Whether to convert ndarray as list. + df_for_struct : bool + If True, convert struct columns to DataFrame instead of Series. + + Returns + ------- + List[Union[pd.Series, pd.DataFrame]] + List of pandas Series (or DataFrame if df_for_struct=True), one for each column. + """ + import pandas as pd + + import pyspark + from pyspark.sql.pandas.types import from_arrow_type + + if batch.num_columns == 0: + return [pd.Series([pyspark._NoValue] * batch.num_rows)] + + return [ + ArrowArrayToPandasConversion.convert( + batch.column(i), + schema[i].dataType if schema is not None else from_arrow_type(batch.column(i).type), + timezone=timezone, + struct_in_pandas=struct_in_pandas, + ndarray_as_list=ndarray_as_list, + df_for_struct=df_for_struct, + ) + for i in range(batch.num_columns) + ] + + @classmethod + def _cast_array( + cls, + arr: "pa.Array", + target_type: "pa.DataType", + safecheck: bool = True, + error_message: Optional[str] = None, + ) -> "pa.Array": + """ + Cast an Arrow Array to a target type with type checking. + + This is a private method used internally by zip_batches. + + Parameters + ---------- + arr : pa.Array + The Arrow Array to cast. + target_type : pa.DataType + Target Arrow data type. + safecheck : bool + If True, use safe casting (fails on overflow/truncation). 
+ error_message : str, optional + Custom error message for type mismatch (used if cast fails). + + Returns + ------- + pa.Array + The casted array if types differ, or original array if types match. + """ + import pyarrow as pa + + assert isinstance(arr, pa.Array) + assert isinstance(target_type, pa.DataType) + + if arr.type == target_type: + return arr + + try: + return arr.cast(target_type=target_type, safe=safecheck) + except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e: + if error_message: + raise PySparkTypeError(error_message) from e + else: + raise PySparkTypeError( + f"Arrow type mismatch. Expected: {target_type}, but got: {arr.type}." + ) from e + + +class PandasBatchTransformer: + """ + Pure functions that transform between pandas DataFrames/Series and Arrow RecordBatches. + They should have no side effects (no I/O, no writing to streams). + + This class provides utility methods for converting between pandas and Arrow formats, + used primarily by Pandas UDF wrappers and serializers. + + """ + + @classmethod + def to_arrow( + cls, + series: Union["pd.Series", "pd.DataFrame", List], + timezone: str, + safecheck: bool = True, + int_to_decimal_coercion_enabled: bool = False, + as_struct: bool = False, + struct_in_pandas: Optional[str] = None, + ignore_unexpected_complex_type_values: bool = False, + error_class: Optional[str] = None, + assign_cols_by_name: bool = True, + arrow_cast: bool = False, + ) -> "pa.RecordBatch": + """ + Convert pandas.Series/DataFrame or list to Arrow RecordBatch. + + Parameters + ---------- + series : pandas.Series, pandas.DataFrame or list + A single series/dataframe, list of series/dataframes, or list of + (data, arrow_type) or (data, arrow_type, spark_type) + timezone : str + Timezone for timestamp conversion + safecheck : bool + Whether to perform safe type checking + int_to_decimal_coercion_enabled : bool + Whether to enable int to decimal coercion + as_struct : bool + If True, treat all inputs as DataFrames and create struct arrays (for UDTF) + struct_in_pandas : str, optional + How struct types are represented. If "dict", struct types require DataFrame input. + ignore_unexpected_complex_type_values : bool + Whether to ignore unexpected complex type values (for UDTF) + error_class : str, optional + Custom error class for type cast errors (for UDTF) + assign_cols_by_name : bool + If True, assign columns by name; otherwise by position (for struct) + arrow_cast : bool + Whether to apply Arrow casting when type mismatches + + Returns + ------- + pyarrow.RecordBatch + Arrow RecordBatch + """ + import pandas as pd + import pyarrow as pa + from pyspark.sql.pandas.types import is_variant + + # Normalize input to a consistent format + # Make input conform to [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] 
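        # Illustrative input shapes (assumed examples, mirroring the docstring above):
        #   pd.Series([...])                               -> wrapped as a single unnamed column
        #   (pd.Series([...]), pa.int64())                 -> single column with an explicit Arrow type
        #   [(df, pa.struct(...), StructType(...)), ...]   -> one (data, arrow_type, spark_type) entry per output column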
+ if ( + not isinstance(series, (list, tuple)) + or (len(series) == 2 and isinstance(series[1], pa.DataType)) + or ( + len(series) == 3 + and isinstance(series[1], pa.DataType) + and isinstance(series[2], DataType) + ) + ): + items: List[Any] = [series] + else: + items = list(series) + tupled = ((s, None) if not isinstance(s, (list, tuple)) else s for s in items) + normalized_series = ((s[0], s[1], None) if len(s) == 2 else s for s in tupled) + + arrs = [] + for s, arrow_type, spark_type in normalized_series: + if as_struct or ( + struct_in_pandas == "dict" + and arrow_type is not None + and pa.types.is_struct(arrow_type) + and not is_variant(arrow_type) + ): + # Struct mode: require DataFrame, create struct array + if not isinstance(s, pd.DataFrame): + error_msg = ( + "Output of an arrow-optimized Python UDTFs expects " + if as_struct + else "Invalid return type. Please make sure that the UDF returns a " + ) + if not as_struct: + error_msg += ( + "pandas.DataFrame when the specified return type is StructType." + ) + else: + error_msg += f"a pandas.DataFrame but got: {type(s)}" + raise PySparkValueError(error_msg) + + # Create struct array from DataFrame + if len(s.columns) == 0: + struct_arr = pa.array([{}] * len(s), arrow_type) + else: + # Determine column selection strategy + use_name_matching = assign_cols_by_name and any( + isinstance(name, str) for name in s.columns + ) + + struct_arrs = [] + for i, field in enumerate(arrow_type): + # Get Series and spark_type based on matching strategy + if use_name_matching: + col_series: "pd.Series" = s[field.name] + field_spark_type = ( + spark_type[field.name].dataType if spark_type is not None else None + ) + else: + col_series = s[s.columns[i]].rename(field.name) + field_spark_type = ( + spark_type[i].dataType if spark_type is not None else None + ) + + struct_arrs.append( + PandasSeriesToArrowConversion.create_array( + col_series, + field.type, + timezone=timezone, + safecheck=safecheck, + spark_type=field_spark_type, + arrow_cast=arrow_cast, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + ) + ) + + struct_arr = pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_type)) + + arrs.append(struct_arr) + else: + # Normal mode: create array from Series + arrs.append( + PandasSeriesToArrowConversion.create_array( + s, + arrow_type, + timezone=timezone, + safecheck=safecheck, + spark_type=spark_type, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + arrow_cast=arrow_cast, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + error_class=error_class, + ) + ) + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + class LocalDataToArrowConversion: """ @@ -484,15 +1016,67 @@ def convert_other(value: Any) -> Any: else: # pragma: no cover assert False, f"Need converter for {dataType} but failed to find one." + @staticmethod + def create_array( + results: Sequence[Any], + arrow_type: "pa.DataType", + spark_type: DataType, + safecheck: bool = True, + int_to_decimal_coercion_enabled: bool = False, + ) -> "pa.Array": + """ + Create an Arrow Array from a sequence of Python values. 
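A minimal usage sketch (the types are an assumed example, chosen only to show the shape of a call):

    import pyarrow as pa
    from pyspark.sql.types import LongType

    arr = LocalDataToArrowConversion.create_array([1, 2, None], pa.int64(), LongType())
    # arr should be a pyarrow int64 array holding [1, 2, null]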
+ + Parameters + ---------- + results : Sequence[Any] + Sequence of Python values to convert + arrow_type : pa.DataType + Target Arrow data type + spark_type : DataType + Spark data type for conversion + safecheck : bool + If True, use safe casting (fails on overflow/truncation) + int_to_decimal_coercion_enabled : bool + If True, applies additional coercions in Python before converting to Arrow + + Returns + ------- + pa.Array + Arrow Array with converted values + """ + import pyarrow as pa + + conv = LocalDataToArrowConversion._create_converter( + spark_type, + none_on_identity=True, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ) + converted = [conv(res) for res in results] if conv is not None else results + try: + return pa.array(converted, type=arrow_type) + except pa.lib.ArrowInvalid: + return pa.array(converted).cast(target_type=arrow_type, safe=safecheck) + @staticmethod def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool) -> "pa.Table": require_minimum_pyarrow_version() import pyarrow as pa + from pyspark.sql.pandas.types import to_arrow_type - assert isinstance(data, list) and len(data) > 0 - + assert isinstance(data, list) assert schema is not None and isinstance(schema, StructType) + # Handle empty data - return empty table with correct schema + if len(data) == 0: + arrow_schema = to_arrow_type( + schema, timezone="UTC", prefers_large_types=use_large_var_types + ) + pa_schema = pa.schema(list(arrow_schema)) + # Create empty arrays for each field to ensure proper table structure + empty_arrays = [pa.array([], type=f.type) for f in pa_schema] + return pa.Table.from_arrays(empty_arrays, schema=pa_schema) + column_names = schema.fieldNames() len_column_names = len(column_names) @@ -1001,7 +1585,147 @@ def localize_tz(arr: "pa.Array") -> "pa.Array": assert False, f"Need converter for {pa_type} but failed to find one." +class PandasSeriesToArrowConversion: + """ + Conversion utilities for converting pandas Series to PyArrow Arrays. + + This class provides methods to convert pandas Series to PyArrow Arrays, + with support for Spark-specific type handling and conversions. + + The class is primarily used by PySpark's Pandas UDF wrappers and serializers, + where pandas data needs to be converted to Arrow for efficient serialization. + """ + + @classmethod + def create_array( + cls, + series: "pd.Series", + arrow_type: Optional["pa.DataType"], + timezone: str, + safecheck: bool = True, + spark_type: Optional[DataType] = None, + arrow_cast: bool = False, + int_to_decimal_coercion_enabled: bool = False, + ignore_unexpected_complex_type_values: bool = False, + error_class: Optional[str] = None, + ) -> "pa.Array": + """ + Create an Arrow Array from the given pandas.Series and optional type. 
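For instance, a sketch under assumed inputs (not exhaustive of the supported types):

    import pandas as pd
    import pyarrow as pa

    arr = PandasSeriesToArrowConversion.create_array(
        pd.Series([1.5, 2.5, None]), pa.float64(), timezone="UTC"
    )
    # arr should be a pyarrow float64 array with a null in the last position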
+ + Parameters + ---------- + series : pandas.Series + A single series + arrow_type : pyarrow.DataType, optional + If None, pyarrow's inferred type will be used + timezone : str + Timezone for timestamp conversion + safecheck : bool + Whether to perform safe type checking + spark_type : DataType, optional + If None, spark type converted from arrow_type will be used + arrow_cast : bool + Whether to apply Arrow casting when type mismatches + int_to_decimal_coercion_enabled : bool + Whether to enable int to decimal coercion + ignore_unexpected_complex_type_values : bool + Whether to ignore unexpected complex type values during conversion + error_class : str, optional + Custom error class for arrow type cast errors (e.g., "UDTF_ARROW_TYPE_CAST_ERROR") + + Returns + ------- + pyarrow.Array + """ + import pyarrow as pa + import pandas as pd + + if isinstance(series.dtype, pd.CategoricalDtype): + series = series.astype(series.dtype.categories.dtype) + + if arrow_type is not None: + dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) + conv = _create_converter_from_pandas( + dt, + timezone=timezone, + error_on_duplicated_field_names=False, + int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values, + ) + series = conv(series) + + if hasattr(series.array, "__arrow_array__"): + mask = None + else: + mask = series.isnull() + + # For UDTF (error_class is set), use Arrow-specific error handling + if error_class is not None: + try: + try: + return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck) + except pa.lib.ArrowException: + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=safecheck + ) + else: + raise + except pa.lib.ArrowException: + raise PySparkRuntimeError( + errorClass=error_class, + messageParameters={ + "col_name": str(series.name), + "col_type": str(series.dtype), + "arrow_type": str(arrow_type), + }, + ) from None + + # For regular UDF, use standard error handling + try: + try: + return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck) + except pa.lib.ArrowInvalid: + # Only catch ArrowInvalid for arrow_cast, not ArrowTypeError + # ArrowTypeError should propagate to the TypeError handler + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=safecheck + ) + else: + raise + except (TypeError, pa.lib.ArrowTypeError) as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e + except (ValueError, pa.lib.ArrowException) as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + if safecheck: + error_msg += ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + ) + raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e + + class ArrowArrayToPandasConversion: + """ + Conversion utilities for converting PyArrow Arrays and ChunkedArrays to pandas. + + This class provides methods to convert PyArrow columnar data structures to pandas + Series or DataFrames, with support for Spark-specific type handling and conversions. 
+ + The class is primarily used by PySpark's Arrow-based serializers for UDF execution, + where Arrow data needs to be converted to pandas for Python UDF processing. + """ + @classmethod def convert_legacy( cls, @@ -1014,14 +1738,39 @@ def convert_legacy( df_for_struct: bool = False, ) -> Union["pd.Series", "pd.DataFrame"]: """ + Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame. + + This is the lower-level conversion method that requires explicit Spark type + specification. For a more convenient API, see :meth:`convert`. + Parameters ---------- - arr : :class:`pyarrow.Array`. - spark_type: target spark type, should always be specified. - timezone : The timezone to convert from. If there is a timestamp type, it's required. - struct_in_pandas : How to handle struct type. If there is a struct type, it's required. - ndarray_as_list : Whether `np.ndarray` is converted to a list or not. - df_for_struct: when true, and spark type is a StructType, return a DataFrame. + arr : pa.Array or pa.ChunkedArray + The arrow column to convert. + spark_type : DataType + Target Spark type. Must be specified and should match the Arrow array type. + timezone : str, optional + The timezone to use for timestamp conversion. Required if the data contains + timestamp types. + struct_in_pandas : str, optional + How to handle struct types in pandas. Valid values are "dict", "row", or "legacy". + Required if the data contains struct types. + ndarray_as_list : bool, optional + Whether to convert numpy ndarrays to Python lists. Default is False. + df_for_struct : bool, optional + If True and spark_type is a StructType, return a DataFrame with columns + corresponding to struct fields instead of a Series. Default is False. + + Returns + ------- + pd.Series or pd.DataFrame + Converted pandas Series. If df_for_struct is True and spark_type is StructType, + returns a DataFrame with columns corresponding to struct fields. + + Notes + ----- + This method handles date type columns specially to avoid overflow issues with + datetime64[ns] intermediate representations. """ import pyarrow as pa import pandas as pd @@ -1032,7 +1781,11 @@ def convert_legacy( import pyarrow.types as types assert types.is_struct(arr.type) - assert len(spark_type.names) == len(arr.type.names), f"{spark_type} {arr.type} " + assert len(spark_type.names) == len(arr.type.names), ( + f"Schema mismatch: spark_type has {len(spark_type.names)} fields, " + f"but arrow type has {len(arr.type.names)} fields. " + f"spark_type={spark_type}, arrow_type={arr.type}" + ) series = [ cls.convert_legacy( @@ -1049,10 +1802,11 @@ def convert_legacy( pdf.columns = spark_type.names # type: ignore[assignment] return pdf - # If the given column is a date type column, creates a series of datetime.date directly - # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by - # datetime64[ns] type handling. - # Cast dates to objects instead of datetime64[ns] dtype to avoid overflow. 
+ # Convert Arrow array to pandas Series with specific options: + # - date_as_object: Convert date types to Python datetime.date objects directly + # instead of datetime64[ns] to avoid overflow issues + # - coerce_temporal_nanoseconds: Handle nanosecond precision timestamps correctly + # - integer_object_nulls: Use object dtype for integer arrays with nulls pandas_options = { "date_as_object": True, "coerce_temporal_nanoseconds": True, @@ -1070,3 +1824,49 @@ def convert_legacy( integer_object_nulls=True, ) return converter(ser) + + @classmethod + def convert( + cls, + arrow_column: Union["pa.Array", "pa.ChunkedArray"], + target_type: DataType, + *, + timezone: Optional[str] = None, + struct_in_pandas: str = "dict", + ndarray_as_list: bool = False, + df_for_struct: bool = False, + ) -> Union["pd.Series", "pd.DataFrame"]: + """ + Convert a PyArrow Array or ChunkedArray to a pandas Series or DataFrame. + + Parameters + ---------- + arrow_column : pa.Array or pa.ChunkedArray + The Arrow column to convert. + target_type : DataType + The target Spark type for the column to be converted to. + timezone : str, optional + Timezone for timestamp conversion. Required if the data contains timestamp types. + struct_in_pandas : str, optional + How to represent struct types in pandas. Valid values are "dict", "row", or "legacy". + Default is "dict". + ndarray_as_list : bool, optional + Whether to convert numpy ndarrays to Python lists. Default is False. + df_for_struct : bool, optional + If True, convert struct columns to a DataFrame with columns corresponding + to struct fields instead of a Series. Default is False. + + Returns + ------- + pd.Series or pd.DataFrame + Converted pandas Series. If df_for_struct is True and the type is StructType, + returns a DataFrame with columns corresponding to struct fields. 
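A rough usage sketch (the column and target type are chosen only for illustration):

    import pyarrow as pa
    from pyspark.sql.types import StringType

    ser = ArrowArrayToPandasConversion.convert(pa.array(["a", None, "c"]), StringType())
    # ser should be a pandas Series of Python strings, with None for the null slot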
+ """ + return cls.convert_legacy( + arrow_column, + target_type, + timezone=timezone, + struct_in_pandas=struct_in_pandas, + ndarray_as_list=ndarray_as_list, + df_for_struct=df_for_struct, + ) diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 333f9803df3ab..7e524b294a18e 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -18,6 +18,7 @@ from typing import ( Any, Callable, + Iterator, List, Optional, Sequence, @@ -807,7 +808,8 @@ def _create_from_pandas_with_arrow( assert isinstance(self, SparkSession) - from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer + from pyspark.sql.pandas.serializers import ArrowStreamSerializer + from pyspark.sql.conversion import PandasBatchTransformer from pyspark.sql.types import TimestampType from pyspark.sql.pandas.types import ( from_arrow_type, @@ -895,7 +897,14 @@ def _create_from_pandas_with_arrow( jsparkSession = self._jsparkSession - ser = ArrowStreamPandasSerializer(timezone, safecheck, False) + # Convert pandas data to Arrow batches + def create_batches() -> Iterator["pa.RecordBatch"]: + for batch_data in arrow_data: + yield PandasBatchTransformer.to_arrow( + batch_data, timezone=timezone, safecheck=safecheck + ) + + ser = ArrowStreamSerializer() @no_type_check def reader_func(temp_filename): @@ -906,7 +915,7 @@ def create_iter_server(): return self._jvm.ArrowIteratorServer() # Create Spark DataFrame from Arrow stream file, using one batch per partition - jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server) + jiter = self._sc._serialize_to_jvm(create_batches(), ser, reader_func, create_iter_server) assert self._jvm is not None jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession) df = DataFrame(jdf, self) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 74599869548d3..58533188017a8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -20,10 +20,9 @@ """ from itertools import groupby -from typing import TYPE_CHECKING, Iterator, Optional import pyspark -from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.errors import PySparkRuntimeError, PySparkValueError from pyspark.serializers import ( Serializer, read_int, @@ -31,21 +30,14 @@ UTF8Deserializer, CPickleSerializer, ) -from pyspark.sql import Row from pyspark.sql.conversion import ( LocalDataToArrowConversion, ArrowTableToRowsConversion, - ArrowArrayToPandasConversion, ArrowBatchTransformer, + PandasBatchTransformer, ) -from pyspark.sql.pandas.types import ( - from_arrow_type, - is_variant, - to_arrow_type, - _create_converter_from_pandas, -) +from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import ( - DataType, StringType, StructType, BinaryType, @@ -54,10 +46,6 @@ IntegerType, ) -if TYPE_CHECKING: - import pandas as pd - import pyarrow as pa - class SpecialLengths: END_OF_DATA_SECTION = -1 @@ -213,589 +201,118 @@ def __repr__(self): return "ArrowStreamSerializer" -class ArrowStreamUDFSerializer(ArrowStreamSerializer): - """ - Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch - for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`. - """ - - def load_stream(self, stream): - """ - Flatten the struct into Arrow's record batches. 
- """ - batches = super().load_stream(stream) - flattened = map(ArrowBatchTransformer.flatten_struct, batches) - return map(lambda b: [b], flattened) - - def dump_stream(self, iterator, stream): - """ - Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - """ - batches = self._write_stream_start( - (ArrowBatchTransformer.wrap_struct(x[0]) for x in iterator), stream - ) - return super().dump_stream(batches, stream) - - -class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer): - """ - Same as :class:`ArrowStreamUDFSerializer` but it does not flatten when loading batches. +class ArrowStreamGroupSerializer(ArrowStreamSerializer): """ + Unified serializer for Arrow stream operations with optional grouping support. - def load_stream(self, stream): - return ArrowStreamSerializer.load_stream(self, stream) + This serializer handles: + - Non-grouped (num_dfs=0): SQL_MAP_ARROW_ITER_UDF, SQL_ARROW_TABLE_UDF, SQL_ARROW_UDTF + - Grouped (num_dfs=1): SQL_GROUPED_MAP_ARROW_UDF, SQL_GROUPED_MAP_PANDAS_UDF, + SQL_GROUPED_AGG_ARROW_UDF, SQL_GROUPED_AGG_PANDAS_UDF + - Cogrouped (num_dfs=2): SQL_COGROUPED_MAP_ARROW_UDF, SQL_COGROUPED_MAP_PANDAS_UDF + The serializer handles Arrow stream I/O and START signal, while transformation logic + (flatten/wrap struct, pandas conversion) is handled by worker wrappers. -class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer): - """ - Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays. + Parameters + ---------- + num_dfs : int, optional + Number of DataFrames per group: + - 0: Non-grouped mode (default) - yields all batches as single stream + - 1: Grouped mode - yields one iterator of batches per group + - 2: Cogrouped mode - yields tuple of two iterators per group """ - def __init__(self, table_arg_offsets=None): + def __init__(self, num_dfs: int = 0): super().__init__() - self.table_arg_offsets = table_arg_offsets if table_arg_offsets else [] + self._num_dfs = num_dfs def load_stream(self, stream): """ - Flatten the struct into Arrow's record batches. - """ - for batch in super().load_stream(stream): - # For each column: flatten struct columns at table_arg_offsets into RecordBatch, - # keep other columns as Array - yield [ - ArrowBatchTransformer.flatten_struct(batch, column_index=i) - if i in self.table_arg_offsets - else batch.column(i) - for i in range(batch.num_columns) - ] - - def _create_array(self, arr, arrow_type): - import pyarrow as pa + Deserialize Arrow record batches from stream. - assert isinstance(arr, pa.Array) - assert isinstance(arrow_type, pa.DataType) - if arr.type == arrow_type: - return arr + Returns + ------- + Iterator + - num_dfs=0: Iterator[pa.RecordBatch] - all batches in a single stream + - num_dfs=1: Iterator[Iterator[pa.RecordBatch]] - one iterator per group + - num_dfs=2: Iterator[Tuple[Iterator[pa.RecordBatch], Iterator[pa.RecordBatch]]] - + tuple of two iterators per cogrouped group + """ + if self._num_dfs > 0: + # Grouped mode: return raw Arrow batches per group + for batch_iters in self._load_group_dataframes(stream, num_dfs=self._num_dfs): + if self._num_dfs == 1: + yield batch_iters[0] + # Make sure the batches are fully iterated before getting the next group + for _ in batch_iters[0]: + pass + else: + yield batch_iters else: - try: - # when safe is True, the cast will fail if there's a overflow or other - # unsafe conversion. - # RecordBatch.cast(...) isn't used as minimum PyArrow version - # required for RecordBatch.cast(...) 
is v16.0 - return arr.cast(target_type=arrow_type, safe=True) - except (pa.ArrowInvalid, pa.ArrowTypeError): - raise PySparkRuntimeError( - errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF", - messageParameters={ - "expected": str(arrow_type), - "actual": str(arr.type), - }, - ) + # Non-grouped mode: return raw Arrow batches + yield from super().load_stream(stream) def dump_stream(self, iterator, stream): """ - Override to handle type coercion for ArrowUDTF outputs. - ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples. - """ - import pyarrow as pa - - def apply_type_coercion(): - for batch, arrow_return_type in iterator: - assert isinstance( - arrow_return_type, pa.StructType - ), f"Expected pa.StructType, got {type(arrow_return_type)}" - - # Handle empty struct case specially - if batch.num_columns == 0: - coerced_batch = batch # skip type coercion - else: - expected_field_names = [field.name for field in arrow_return_type] - actual_field_names = batch.schema.names - - if expected_field_names != actual_field_names: - raise PySparkTypeError( - "Target schema's field names are not matching the record batch's " - "field names. " - f"Expected: {expected_field_names}, but got: {actual_field_names}." - ) - - coerced_arrays = [] - for i, field in enumerate(arrow_return_type): - original_array = batch.column(i) - coerced_array = self._create_array(original_array, field.type) - coerced_arrays.append(coerced_array) - coerced_batch = pa.RecordBatch.from_arrays( - coerced_arrays, names=expected_field_names - ) - yield coerced_batch, arrow_return_type - - return super().dump_stream(apply_type_coercion(), stream) - - -class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer): - """ - Serializer for grouped Arrow UDFs. - - Deserializes: - ``Iterator[Iterator[pa.RecordBatch]]`` - one inner iterator per group. - Each batch contains a single struct column. - - Serializes: - ``Iterator[Tuple[Iterator[pa.RecordBatch], pa.DataType]]`` - Each tuple contains iterator of flattened batches and their Arrow type. - - Used by: - - SQL_GROUPED_MAP_ARROW_UDF - - SQL_GROUPED_MAP_ARROW_ITER_UDF - - Parameters - ---------- - assign_cols_by_name : bool - If True, reorder serialized columns by schema name. - """ + Serialize Arrow record batches to stream with START signal. - def __init__(self, assign_cols_by_name): - super().__init__() - self._assign_cols_by_name = assign_cols_by_name + The START_ARROW_STREAM marker is sent before the first batch to signal + the JVM that Arrow data is about to be transmitted. This allows proper + error handling during batch creation. - def load_stream(self, stream): - """ - Load grouped Arrow record batches from stream. + Parameters + ---------- + iterator : Iterator[pa.RecordBatch] + Iterator of Arrow RecordBatches to serialize. For grouped/cogrouped UDFs, + this iterator is already flattened by the worker's func layer. 
+ stream : file-like object + Output stream to write the serialized batches """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - yield batches - # Make sure the batches are fully iterated before getting the next group - for _ in batches: - pass - - def dump_stream(self, iterator, stream): - import pyarrow as pa - - # flatten inner list [([pa.RecordBatch], arrow_type)] into [(pa.RecordBatch, arrow_type)] - # so strip off inner iterator induced by ArrowStreamUDFSerializer.load_stream - batch_iter = ( - (batch, arrow_type) - for batches, arrow_type in iterator # tuple constructed in wrap_grouped_map_arrow_udf - for batch in batches - ) - - if self._assign_cols_by_name: - batch_iter = ( - ( - pa.RecordBatch.from_arrays( - [batch.column(field.name) for field in arrow_type], - names=[field.name for field in arrow_type], - ), - arrow_type, - ) - for batch, arrow_type in batch_iter - ) - - super().dump_stream(batch_iter, stream) + batches = self._write_stream_start(iterator, stream) + return super().dump_stream(batches, stream) -class ArrowStreamPandasSerializer(ArrowStreamSerializer): +class ArrowStreamUDFSerializer(ArrowStreamSerializer): """ - Serializes pandas.Series as Arrow data with Arrow streaming format. + Serializer for UDFs that handles Arrow RecordBatch serialization. + + This is a thin wrapper around ArrowStreamSerializer kept for backward compatibility + and future extensibility. Currently it doesn't override any methods - all conversion + logic has been moved to wrappers/callers, and this serializer only handles pure + Arrow stream serialization. Parameters ---------- timezone : str - A timezone to respect when handling timestamp values + A timezone to respect when handling timestamp values (for compatibility) safecheck : bool - If True, conversion from Arrow to Pandas checks for overflow/truncation - assign_cols_by_name : bool - If True, then Pandas DataFrames will get columns by name + If True, conversion checks for overflow/truncation (for compatibility) int_to_decimal_coercion_enabled : bool - If True, applies additional coercions in Python before converting to Arrow - This has performance penalties. - """ - - def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled): - super().__init__() - self._timezone = timezone - self._safecheck = safecheck - self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled - - def arrow_to_pandas( - self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None - ): - return ArrowArrayToPandasConversion.convert_legacy( - arrow_column, - spark_type or from_arrow_type(arrow_column.type), - timezone=self._timezone, - struct_in_pandas=struct_in_pandas, - ndarray_as_list=ndarray_as_list, - df_for_struct=False, - ) - - def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): - """ - Create an Arrow Array from the given pandas.Series and optional type. - - Parameters - ---------- - series : pandas.Series - A single series - arrow_type : pyarrow.DataType, optional - If None, pyarrow's inferred type will be used - spark_type : DataType, optional - If None, spark type converted from arrow_type will be used - arrow_cast: bool, optional - Whether to apply Arrow casting when the user-specified return type mismatches the - actual return values. 
- - Returns - ------- - pyarrow.Array - """ - import pyarrow as pa - import pandas as pd - - if isinstance(series.dtype, pd.CategoricalDtype): - series = series.astype(series.dtypes.categories.dtype) - - if arrow_type is not None: - dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) - conv = _create_converter_from_pandas( - dt, - timezone=self._timezone, - error_on_duplicated_field_names=False, - int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, - ) - series = conv(series) - - if hasattr(series.array, "__arrow_array__"): - mask = None - else: - mask = series.isnull() - try: - try: - return pa.Array.from_pandas( - series, mask=mask, type=arrow_type, safe=self._safecheck - ) - except pa.lib.ArrowInvalid: - if arrow_cast: - return pa.Array.from_pandas(series, mask=mask).cast( - target_type=arrow_type, safe=self._safecheck - ) - else: - raise - except TypeError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e - except ValueError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - if self._safecheck: - error_msg = error_msg + ( - " It can be caused by overflows or other " - "unsafe conversions warned by Arrow. Arrow safe type check " - "can be disabled by using SQL config " - "`spark.sql.execution.pandas.convertToArrowArraySafely`." - ) - raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e - - def _create_batch(self, series): - """ - Create an Arrow record batch from the given pandas.Series or list of Series, - with optional type. - - Parameters - ---------- - series : pandas.Series or list - A single series, list of series, or list of (series, arrow_type) - - Returns - ------- - pyarrow.RecordBatch - Arrow RecordBatch - """ - import pyarrow as pa - - # Make input conform to - # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] - if ( - not isinstance(series, (list, tuple)) - or (len(series) == 2 and isinstance(series[1], pa.DataType)) - or ( - len(series) == 3 - and isinstance(series[1], pa.DataType) - and isinstance(series[2], DataType) - ) - ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) - - arrs = [ - self._create_array(s, arrow_type, spark_type) for s, arrow_type, spark_type in series - ] - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - def dump_stream(self, iterator, stream): - """ - Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or - a list of series accompanied by an optional pyarrow type to coerce the data to. - """ - batches = (self._create_batch(series) for series in iterator) - super().dump_stream(batches, stream) - - def load_stream(self, stream): - """ - Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
- """ - batches = super().load_stream(stream) - import pandas as pd - - for batch in batches: - pandas_batches = [ - self.arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) - ] - if len(pandas_batches) == 0: - yield [pd.Series([pyspark._NoValue] * batch.num_rows)] - else: - yield pandas_batches - - def __repr__(self): - return "ArrowStreamPandasSerializer" - - -class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): - """ - Serializer used by Python worker to evaluate Pandas UDFs + If True, applies additional coercions (for compatibility) """ def __init__( self, timezone, safecheck, - assign_cols_by_name, - df_for_struct: bool = False, - struct_in_pandas: str = "dict", - ndarray_as_list: bool = False, - arrow_cast: bool = False, - input_type: Optional[StructType] = None, int_to_decimal_coercion_enabled: bool = False, - ): - super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled) - self._assign_cols_by_name = assign_cols_by_name - self._df_for_struct = df_for_struct - self._struct_in_pandas = struct_in_pandas - self._ndarray_as_list = ndarray_as_list - self._arrow_cast = arrow_cast - if input_type is not None: - assert isinstance(input_type, StructType) - self._input_type = input_type - - def arrow_to_pandas(self, arrow_column, idx): - if self._input_type is not None: - spark_type = self._input_type[idx].dataType - else: - spark_type = from_arrow_type(arrow_column.type) - - return ArrowArrayToPandasConversion.convert_legacy( - arr=arrow_column, - spark_type=spark_type, - timezone=self._timezone, - struct_in_pandas=self._struct_in_pandas, - ndarray_as_list=self._ndarray_as_list, - df_for_struct=self._df_for_struct, - ) - - def _create_struct_array( - self, - df: "pd.DataFrame", - arrow_struct_type: "pa.StructType", - spark_type: Optional[StructType] = None, - ): - """ - Create an Arrow StructArray from the given pandas.DataFrame and arrow struct type. - - Parameters - ---------- - df : pandas.DataFrame - A pandas DataFrame - arrow_struct_type : pyarrow.StructType - pyarrow struct type - - Returns - ------- - pyarrow.Array - """ - import pyarrow as pa - - if len(df.columns) == 0: - return pa.array([{}] * len(df), arrow_struct_type) - # Assign result columns by schema name if user labeled with strings - if self._assign_cols_by_name and any(isinstance(name, str) for name in df.columns): - struct_arrs = [ - self._create_array( - df[field.name], - field.type, - spark_type=( - spark_type[field.name].dataType if spark_type is not None else None - ), - arrow_cast=self._arrow_cast, - ) - for field in arrow_struct_type - ] - # Assign result columns by position - else: - struct_arrs = [ - # the selected series has name '1', so we rename it to field.name - # as the name is used by _create_array to provide a meaningful error message - self._create_array( - df[df.columns[i]].rename(field.name), - field.type, - spark_type=spark_type[i].dataType if spark_type is not None else None, - arrow_cast=self._arrow_cast, - ) - for i, field in enumerate(arrow_struct_type) - ] - - return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type)) - - def _create_batch(self, series): - """ - Create an Arrow record batch from the given pandas.Series pandas.DataFrame - or list of Series or DataFrame, with optional type. 
- - Parameters - ---------- - series : pandas.Series or pandas.DataFrame or list - A single series or dataframe, list of series or dataframe, - or list of (series or dataframe, arrow_type) - - Returns - ------- - pyarrow.RecordBatch - Arrow RecordBatch - """ - import pandas as pd - import pyarrow as pa - - # Make input conform to - # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] - if ( - not isinstance(series, (list, tuple)) - or (len(series) == 2 and isinstance(series[1], pa.DataType)) - or ( - len(series) == 3 - and isinstance(series[1], pa.DataType) - and isinstance(series[2], DataType) - ) - ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) - - arrs = [] - for s, arrow_type, spark_type in series: - # Variants are represented in arrow as structs with additional metadata (checked by - # is_variant). If the data type is Variant, return a VariantVal atomic type instead of - # a dict of two binary values. - if ( - self._struct_in_pandas == "dict" - and arrow_type is not None - and pa.types.is_struct(arrow_type) - and not is_variant(arrow_type) - ): - # A pandas UDF should return pd.DataFrame when the return type is a struct type. - # If it returns a pd.Series, it should throw an error. - if not isinstance(s, pd.DataFrame): - raise PySparkValueError( - "Invalid return type. Please make sure that the UDF returns a " - "pandas.DataFrame when the specified return type is StructType." - ) - arrs.append(self._create_struct_array(s, arrow_type, spark_type=spark_type)) - else: - arrs.append( - self._create_array( - s, arrow_type, spark_type=spark_type, arrow_cast=self._arrow_cast - ) - ) - - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - def dump_stream(self, iterator, stream): - """ - Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. - """ - batches = self._write_stream_start(map(self._create_batch, iterator), stream) - return ArrowStreamSerializer.dump_stream(self, batches, stream) - - def __repr__(self): - return "ArrowStreamPandasUDFSerializer" - - -class ArrowStreamArrowUDFSerializer(ArrowStreamSerializer): - """ - Serializer used by Python worker to evaluate Arrow UDFs - """ - - def __init__( - self, - safecheck, - arrow_cast, ): super().__init__() + # Store parameters for backward compatibility + self._timezone = timezone self._safecheck = safecheck - self._arrow_cast = arrow_cast - - def _create_array(self, arr, arrow_type, arrow_cast): - import pyarrow as pa - - assert isinstance(arr, pa.Array) - assert isinstance(arrow_type, pa.DataType) - - if arr.type == arrow_type: - return arr - elif arrow_cast: - return arr.cast(target_type=arrow_type, safe=self._safecheck) - else: - raise PySparkTypeError( - "Arrow UDFs require the return type to match the expected Arrow type. " - f"Expected: {arrow_type}, but got: {arr.type}." - ) - - def dump_stream(self, iterator, stream): - """ - Override because Arrow UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. 
- """ - import pyarrow as pa - - def create_batches(): - for packed in iterator: - if len(packed) == 2 and isinstance(packed[1], pa.DataType): - # single array UDF in a projection - arrs = [self._create_array(packed[0], packed[1], self._arrow_cast)] - else: - # multiple array UDFs in a projection - arrs = [self._create_array(t[0], t[1], self._arrow_cast) for t in packed] - yield pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - batches = self._write_stream_start(create_batches(), stream) - return ArrowStreamSerializer.dump_stream(self, batches, stream) + self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled def __repr__(self): - return "ArrowStreamArrowUDFSerializer" + return "ArrowStreamUDFSerializer" -class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer): +class ArrowBatchUDFSerializer(ArrowStreamGroupSerializer): """ Serializer used by Python worker to evaluate Arrow Python UDFs when the legacy pandas conversion is disabled - (instead of legacy ArrowStreamPandasUDFSerializer). + (instead of legacy ArrowStreamGroupSerializer with pandas conversion in serializer). Parameters ---------- @@ -817,11 +334,9 @@ def __init__( int_to_decimal_coercion_enabled: bool, binary_as_bytes: bool, ): - super().__init__( - safecheck=safecheck, - arrow_cast=True, - ) + super().__init__(num_dfs=0) assert isinstance(input_type, StructType) + self._safecheck = safecheck self._input_type = input_type self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled self._binary_as_bytes = binary_as_bytes @@ -848,14 +363,13 @@ def load_stream(self, stream): ] for batch in super().load_stream(stream): - columns = [ - [conv(v) for v in column.to_pylist()] if conv is not None else column.to_pylist() - for column, conv in zip(batch.itercolumns(), converters) - ] - if len(columns) == 0: + if batch.num_columns == 0: yield [[pyspark._NoValue] * batch.num_rows] else: - yield columns + yield [ + [conv(v) for v in col.to_pylist()] if conv else col.to_pylist() + for col, conv in zip(batch.itercolumns(), converters) + ] def dump_stream(self, iterator, stream): """ @@ -872,353 +386,39 @@ def dump_stream(self, iterator, stream): Returns ------- object - Result of writing the Arrow stream via ArrowStreamArrowUDFSerializer dump_stream + Result of writing the Arrow stream via ArrowStreamGroupSerializer dump_stream """ import pyarrow as pa def create_array(results, arrow_type, spark_type): - conv = LocalDataToArrowConversion._create_converter( + return LocalDataToArrowConversion.create_array( + results, + arrow_type, spark_type, - none_on_identity=True, + safecheck=self._safecheck, int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, ) - converted = [conv(res) for res in results] if conv is not None else results - try: - return pa.array(converted, type=arrow_type) - except pa.lib.ArrowInvalid: - return pa.array(converted).cast(target_type=arrow_type, safe=self._safecheck) def py_to_batch(): for packed in iterator: if len(packed) == 3 and isinstance(packed[1], pa.DataType): # single array UDF in a projection - yield create_array(packed[0], packed[1], packed[2]), packed[1] + yield ArrowBatchTransformer.zip_batches( + [(create_array(packed[0], packed[1], packed[2]), packed[1])], + safecheck=self._safecheck, + ) else: # multiple array UDFs in a projection - yield [(create_array(*t), t[1]) for t in packed] - - return super().dump_stream(py_to_batch(), stream) - - -class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer): - """ - Serializer used by 
Python worker to evaluate Arrow-optimized Python UDTFs. - """ - - def __init__(self, timezone, safecheck, input_type, int_to_decimal_coercion_enabled): - super().__init__( - timezone=timezone, - safecheck=safecheck, - # The output pandas DataFrame's columns are unnamed. - assign_cols_by_name=False, - # Set to 'False' to avoid converting struct type inputs into a pandas DataFrame. - df_for_struct=False, - # Defines how struct type inputs are converted. If set to "row", struct type inputs - # are converted into Rows. Without this setting, a struct type input would be treated - # as a dictionary. For example, for named_struct('name', 'Alice', 'age', 1), - # if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 1} - # if struct_in_pandas="row", it becomes Row(name="Alice", age=1) - struct_in_pandas="row", - # When dealing with array type inputs, Arrow converts them into numpy.ndarrays. - # To ensure consistency across regular and arrow-optimized UDTFs, we further - # convert these numpy.ndarrays into Python lists. - ndarray_as_list=True, - # Enables explicit casting for mismatched return types of Arrow Python UDTFs. - arrow_cast=True, - input_type=input_type, - # Enable additional coercions for UDTF serialization - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ) - - def _create_batch(self, series): - """ - Create an Arrow record batch from the given pandas.Series pandas.DataFrame - or list of Series or DataFrame, with optional type. - - Parameters - ---------- - series : pandas.Series or pandas.DataFrame or list - A single series or dataframe, list of series or dataframe, - or list of (series or dataframe, arrow_type) - - Returns - ------- - pyarrow.RecordBatch - Arrow RecordBatch - """ - import pandas as pd - import pyarrow as pa - - # Make input conform to - # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...] - if ( - not isinstance(series, (list, tuple)) - or (len(series) == 2 and isinstance(series[1], pa.DataType)) - or ( - len(series) == 3 - and isinstance(series[1], pa.DataType) - and isinstance(series[2], DataType) - ) - ): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - series = ((s[0], s[1], None) if len(s) == 2 else s for s in series) - - arrs = [] - for s, arrow_type, spark_type in series: - if not isinstance(s, pd.DataFrame): - raise PySparkValueError( - "Output of an arrow-optimized Python UDTFs expects " - f"a pandas.DataFrame but got: {type(s)}" - ) - - arrs.append(self._create_struct_array(s, arrow_type, spark_type)) - - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): - """ - Override the `_create_array` method in the superclass to create an Arrow Array - from a given pandas.Series and an arrow type. The difference here is that we always - use arrow cast when creating the arrow array. Also, the error messages are specific - to arrow-optimized Python UDTFs. - - Parameters - ---------- - series : pandas.Series - A single series - arrow_type : pyarrow.DataType, optional - If None, pyarrow's inferred type will be used - spark_type : DataType, optional - If None, spark type converted from arrow_type will be used - arrow_cast: bool, optional - Whether to apply Arrow casting when the user-specified return type mismatches the - actual return values. 
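An illustrative sketch (editor's note, not part of the patch) of the ArrowBatchTransformer.zip_batches helper that the new ArrowBatchUDFSerializer.dump_stream relies on above: (array, target_type) pairs are cast as needed and zipped column-wise into a single RecordBatch. It assumes this patch is applied so that pyspark.sql.conversion exposes the new class.

import pyarrow as pa
from pyspark.sql.conversion import ArrowBatchTransformer

# Two UDF results in one projection: an int32 array that must be cast to int64,
# and a string array that already matches its target type.
int_arr = pa.array([1, 2, 3], type=pa.int32())
str_arr = pa.array(["a", "b", "c"])
batch = ArrowBatchTransformer.zip_batches([(int_arr, pa.int64()), (str_arr, pa.string())])
assert batch.num_columns == 2
assert batch.column(0).type == pa.int64()
assert batch.column(1).to_pylist() == ["a", "b", "c"]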
- - Returns - ------- - pyarrow.Array - """ - import pyarrow as pa - import pandas as pd - - if isinstance(series.dtype, pd.CategoricalDtype): - series = series.astype(series.dtypes.categories.dtype) - - if arrow_type is not None: - dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) - conv = _create_converter_from_pandas( - dt, - timezone=self._timezone, - error_on_duplicated_field_names=False, - ignore_unexpected_complex_type_values=True, - int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, - ) - series = conv(series) - - if hasattr(series.array, "__arrow_array__"): - mask = None - else: - mask = series.isnull() - - try: - try: - return pa.Array.from_pandas( - series, mask=mask, type=arrow_type, safe=self._safecheck - ) - except pa.lib.ArrowException: - if arrow_cast: - return pa.Array.from_pandas(series, mask=mask).cast( - target_type=arrow_type, safe=self._safecheck + arrays_and_types = [(create_array(*t), t[1]) for t in packed] + yield ArrowBatchTransformer.zip_batches( + arrays_and_types, + safecheck=self._safecheck, ) - else: - raise - except pa.lib.ArrowException: - # Display the most user-friendly error messages instead of showing - # arrow's error message. This also works better with Spark Connect - # where the exception messages are by default truncated. - raise PySparkRuntimeError( - errorClass="UDTF_ARROW_TYPE_CAST_ERROR", - messageParameters={ - "col_name": series.name, - "col_type": str(series.dtype), - "arrow_type": arrow_type, - }, - ) from None - - def __repr__(self): - return "ArrowStreamPandasUDTFSerializer" - - -# Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF, -# and SQL_GROUPED_AGG_ARROW_ITER_UDF -class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer): - def load_stream(self, stream): - """ - Yield an iterator that produces one tuple of column arrays per batch. - Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one - without consuming all batches upfront. - """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - # Lazily read and convert Arrow batches one at a time from the stream - # This avoids loading all batches into memory for the group - columns_iter = (batch.columns for batch in batches) - yield columns_iter - # Make sure the batches are fully iterated before getting the next group - for _ in columns_iter: - pass - - def __repr__(self): - return "ArrowStreamAggArrowUDFSerializer" - - -# Serializer for SQL_GROUPED_AGG_PANDAS_UDF, SQL_WINDOW_AGG_PANDAS_UDF, -# and SQL_GROUPED_AGG_PANDAS_ITER_UDF -class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer): - def __init__( - self, - timezone, - safecheck, - assign_cols_by_name, - int_to_decimal_coercion_enabled, - ): - super().__init__( - timezone=timezone, - safecheck=safecheck, - assign_cols_by_name=assign_cols_by_name, - df_for_struct=False, - struct_in_pandas="dict", - ndarray_as_list=False, - arrow_cast=True, - input_type=None, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ) - - def load_stream(self, stream): - """ - Yield an iterator that produces one tuple of pandas.Series per batch. - Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to - process batches one by one without consuming all batches upfront. - """ - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - # Lazily read and convert Arrow batches to pandas Series one at a time - # from the stream. 
This avoids loading all batches into memory for the group - series_iter = ( - tuple(self.arrow_to_pandas(c, i) for i, c in enumerate(batch.columns)) - for batch in batches - ) - yield series_iter - # Make sure the batches are fully iterated before getting the next group - for _ in series_iter: - pass - def __repr__(self): - return "ArrowStreamAggPandasUDFSerializer" - - -# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF -class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): - def __init__( - self, - timezone, - safecheck, - assign_cols_by_name, - int_to_decimal_coercion_enabled, - ): - super().__init__( - timezone=timezone, - safecheck=safecheck, - assign_cols_by_name=assign_cols_by_name, - df_for_struct=False, - struct_in_pandas="dict", - ndarray_as_list=False, - arrow_cast=True, - input_type=None, - int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - ) - - def load_stream(self, stream): - """ - Deserialize Grouped ArrowRecordBatches and yield as Iterator[Iterator[pd.Series]]. - Each outer iterator element represents a group, containing an iterator of Series lists - (one list per batch). - """ - - def process_group(batches: "Iterator[pa.RecordBatch]"): - # Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch - for batch in batches: - # The batch from ArrowStreamSerializer is already flattened (no struct wrapper) - series = [ - self.arrow_to_pandas(batch.column(i), i) for i in range(batch.num_columns) - ] - yield series - - for (batches,) in self._load_group_dataframes(stream, num_dfs=1): - # Lazily read and convert Arrow batches one at a time from the stream - # This avoids loading all batches into memory for the group - series_iter = process_group(batches) - yield series_iter - # Make sure the batches are fully iterated before getting the next group - for _ in series_iter: - pass - - def dump_stream(self, iterator, stream): - """ - Flatten the Iterator[Iterator[[(df, arrow_type)]]] returned by func. - The mapper returns Iterator[[(df, arrow_type)]], so we flatten one level - to match the parent's expected format Iterator[[(df, arrow_type)]]. - """ - # Flatten: Iterator[Iterator[[(df, arrow_type)]]] -> Iterator[[(df, arrow_type)]] - flattened_iter = (batch for generator in iterator for batch in generator) - super().dump_stream(flattened_iter, stream) - - def __repr__(self): - return "GroupPandasUDFSerializer" - - -class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): - """ - Serializes pyarrow.RecordBatch data with Arrow streaming format. - - Loads Arrow record batches as `[([pa.RecordBatch], [pa.RecordBatch])]` (one tuple per group) - and serializes `[([pa.RecordBatch], arrow_type)]`. - - Parameters - ---------- - assign_cols_by_name : bool - If True, then DataFrames will get columns by name - """ - - def load_stream(self, stream): - """ - Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es. - """ - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield left_batches, right_batches - - -class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): - def load_stream(self, stream): - """ - Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two - lists of pandas.Series. 
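The removed aggregation and grouped-map serializers above share one non-obvious pattern: each group is handed to the UDF as a lazy iterator over its batches, and the serializer then drains whatever the UDF left unread so the Arrow stream is positioned at the next group. A schematic sketch of that pattern; load_groups, convert, and groups_of_batches are placeholder names, not PySpark APIs.

def load_groups(groups_of_batches, convert):
    for batches in groups_of_batches:
        # Hand out a lazy view so the whole group never has to be materialized.
        lazy = (convert(batch) for batch in batches)
        yield lazy
        # Exhaust anything the consumer skipped before reading the next group.
        for _ in lazy:
            pass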
- """ - import pyarrow as pa - - for left_batches, right_batches in self._load_group_dataframes(stream, num_dfs=2): - yield ( - [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(pa.Table.from_batches(left_batches).itercolumns()) - ], - [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(pa.Table.from_batches(right_batches).itercolumns()) - ], - ) + return super().dump_stream(py_to_batch(), stream) -class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): +class ApplyInPandasWithStateSerializer(ArrowStreamUDFSerializer): """ Serializer used by Python worker to evaluate UDF for applyInPandasWithState. @@ -1247,12 +447,16 @@ def __init__( int_to_decimal_coercion_enabled, ): super().__init__( - timezone, - safecheck, - assign_cols_by_name, + timezone=timezone, + safecheck=safecheck, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - arrow_cast=True, ) + self._df_for_struct = False + self._struct_in_pandas = "dict" + self._ndarray_as_list = False + self._input_type = None + self._assign_cols_by_name = assign_cols_by_name + self._arrow_cast = True self.pickleSer = CPickleSerializer() self.utf8_deserializer = UTF8Deserializer() self.state_object_schema = state_object_schema @@ -1374,8 +578,14 @@ def gen_data_and_state(batches): schema=state_schema, ) - state_arrow = pa.Table.from_batches([state_batch]).itercolumns() - state_pandas = [self.arrow_to_pandas(c, i) for i, c in enumerate(state_arrow)][0] + state_pandas = ArrowBatchTransformer.to_pandas( + state_batch, + timezone=self._timezone, + schema=None, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + )[0] for state_idx in range(0, len(state_pandas)): state_info_col = state_pandas.iloc[state_idx] @@ -1405,9 +615,14 @@ def gen_data_and_state(batches): state_for_current_group = state data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) - data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() - - data_pandas = [self.arrow_to_pandas(c, i) for i, c in enumerate(data_arrow)] + data_pandas = ArrowBatchTransformer.to_pandas( + data_batch_for_group, + timezone=self._timezone, + schema=None, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ) # state info yield ( @@ -1415,7 +630,7 @@ def gen_data_and_state(batches): state, ) - _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + _batches = super().load_stream(stream) data_state_generator = gen_data_and_state(_batches) @@ -1507,12 +722,18 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat merged_pdf = pd.concat(pdfs, ignore_index=True) merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) - return self._create_batch( + return PandasBatchTransformer.to_arrow( [ (count_pdf, self.result_count_pdf_arrow_type), (merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type), - ] + ], + timezone=self._timezone, + safecheck=self._safecheck, + int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, + struct_in_pandas=self._struct_in_pandas, + assign_cols_by_name=self._assign_cols_by_name, + arrow_cast=self._arrow_cast, ) def serialize_batches(): @@ -1586,7 +807,7 @@ def serialize_batches(): return ArrowStreamSerializer.dump_stream(self, batches, stream) -class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): +class TransformWithStateInPandasSerializer(ArrowStreamUDFSerializer): """ 
Serializer used by Python worker to evaluate UDF for :meth:`pyspark.sql.GroupedData.transformWithStateInPandasSerializer`. @@ -1613,12 +834,16 @@ def __init__( int_to_decimal_coercion_enabled, ): super().__init__( - timezone, - safecheck, - assign_cols_by_name, + timezone=timezone, + safecheck=safecheck, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, - arrow_cast=True, ) + self._df_for_struct = False + self._struct_in_pandas = "dict" + self._ndarray_as_list = False + self._input_type = None + self._assign_cols_by_name = assign_cols_by_name + self._arrow_cast = True self.arrow_max_records_per_batch = ( arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 ) @@ -1650,7 +875,6 @@ def load_stream(self, stream): Please refer the doc of inner function `generate_data_batches` for more details how this function works in overall. """ - import pyarrow as pa import pandas as pd from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, @@ -1670,10 +894,14 @@ def generate_data_batches(batches): def row_stream(): for batch in batches: self._update_batch_size_stats(batch) - data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) - ] + data_pandas = ArrowBatchTransformer.to_pandas( + batch, + timezone=self._timezone, + schema=self._input_type, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ) for row in pd.concat(data_pandas, axis=1).itertuples(index=False): batch_key = tuple(row[s] for s in self.key_offsets) yield (batch_key, row) @@ -1691,7 +919,7 @@ def row_stream(): if rows: yield (batch_key, pd.DataFrame(rows)) - _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + _batches = super().load_stream(stream) data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): @@ -1706,17 +934,37 @@ def dump_stream(self, iterator, stream): Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow RecordBatches, and write batches to stream. 
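A small sketch (editor's illustration, assuming the patched pyspark.sql.conversion module) of the ArrowBatchTransformer.to_pandas call that replaces the per-column arrow_to_pandas loops in the state serializers above: it returns one pandas Series per Arrow column, converted with the supplied session timezone.

import pyarrow as pa
from pyspark.sql.conversion import ArrowBatchTransformer

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
    ["id", "name"],
)
# One pandas Series per column; struct_in_pandas / ndarray_as_list / df_for_struct
# keep their defaults here.
id_series, name_series = ArrowBatchTransformer.to_pandas(batch, timezone="UTC")
assert id_series.tolist() == [1, 2, 3]
assert name_series.tolist() == ["a", "b", "c"]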
""" + from pyspark.sql.conversion import ArrowBatchTransformer - def flatten_iterator(): - # iterator: iter[list[(iter[pandas.DataFrame], pdf_type)]] + def create_batches(): + # iterator: iter[list[(iter[pandas.DataFrame], arrow_return_type)]] for packed in iterator: - iter_pdf_with_type = packed[0] - iter_pdf = iter_pdf_with_type[0] - pdf_type = iter_pdf_with_type[1] + iter_pdf, arrow_return_type = packed[0] for pdf in iter_pdf: - yield (pdf, pdf_type) + # Extract columns from DataFrame and pair with Arrow types + # Similar to wrap_grouped_map_pandas_udf pattern + if self._assign_cols_by_name and any( + isinstance(name, str) for name in pdf.columns + ): + series_list = [(pdf[field.name], field.type) for field in arrow_return_type] + else: + series_list = [ + (pdf[pdf.columns[i]].rename(field.name), field.type) + for i, field in enumerate(arrow_return_type) + ] + batch = PandasBatchTransformer.to_arrow( + series_list, + timezone=self._timezone, + safecheck=self._safecheck, + int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled, + arrow_cast=self._arrow_cast, + ) + # Wrap columns as struct for JVM compatibility + yield ArrowBatchTransformer.wrap_struct(batch) - super().dump_stream(flatten_iterator(), stream) + # Write START_ARROW_STREAM marker before first batch + batches = self._write_stream_start(create_batches(), stream) + ArrowStreamSerializer.dump_stream(self, batches, stream) class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer): @@ -1795,29 +1043,36 @@ def flatten_columns(cur_batch, col_name): but each batch will have either init_data or input_data, not mix. """ + def to_pandas(table): + return ArrowBatchTransformer.to_pandas( + table, + timezone=self._timezone, + schema=self._input_type, + struct_in_pandas=self._struct_in_pandas, + ndarray_as_list=self._ndarray_as_list, + df_for_struct=self._df_for_struct, + ) + def row_stream(): for batch in batches: self._update_batch_size_stats(batch) - flatten_state_table = flatten_columns(batch, "inputData") - data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_state_table.itercolumns()) - ] + data_table = flatten_columns(batch, "inputData") + init_table = flatten_columns(batch, "initState") - flatten_init_table = flatten_columns(batch, "initState") - init_data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_init_table.itercolumns()) - ] + # Check column count - empty table has no columns + has_data = data_table.num_columns > 0 + has_init = init_table.num_columns > 0 - assert not (bool(init_data_pandas) and bool(data_pandas)) + assert not (has_data and has_init) - if bool(data_pandas): + if has_data: + data_pandas = to_pandas(data_table) for row in pd.concat(data_pandas, axis=1).itertuples(index=False): batch_key = tuple(row[s] for s in self.key_offsets) yield (batch_key, row, None) - elif bool(init_data_pandas): + elif has_init: + init_data_pandas = to_pandas(init_table) for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False): batch_key = tuple(row[s] for s in self.init_key_offsets) yield (batch_key, None, row) @@ -1855,7 +1110,9 @@ def row_stream(): else EMPTY_DATAFRAME.copy(), ) - _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + # Call ArrowStreamSerializer.load_stream directly to get raw Arrow batches + # (not the parent's load_stream which returns processed data with mode info) + _batches = ArrowStreamSerializer.load_stream(self, stream) data_batches = generate_data_batches(_batches) for k, g in 
groupby(data_batches, key=lambda x: x[0]): @@ -1864,209 +1121,3 @@ def row_stream(): yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) - - -class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer): - """ - Serializer used by Python worker to evaluate UDF for - :meth:`pyspark.sql.GroupedData.transformWithState`. - - Parameters - ---------- - arrow_max_records_per_batch : int - Limit of the number of records that can be written to a single ArrowRecordBatch in memory. - """ - - def __init__(self, arrow_max_records_per_batch): - super().__init__() - self.arrow_max_records_per_batch = ( - arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 - ) - self.key_offsets = None - - def load_stream(self, stream): - """ - Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunks, - and convert the data into a list of pandas.Series. - - Please refer the doc of inner function `generate_data_batches` for more details how - this function works in overall. - """ - from pyspark.sql.streaming.stateful_processor_util import ( - TransformWithStateInPandasFuncMode, - ) - import itertools - - def generate_data_batches(batches): - """ - Deserialize ArrowRecordBatches and return a generator of Row. - - The deserialization logic assumes that Arrow RecordBatches contain the data with the - ordering that data chunks for same grouping key will appear sequentially. - - This function must avoid materializing multiple Arrow RecordBatches into memory at the - same time. And data chunks from the same grouping key should appear sequentially. - """ - for batch in batches: - DataRow = Row(*batch.schema.names) - - # Iterate row by row without converting the whole batch - num_cols = batch.num_columns - for row_idx in range(batch.num_rows): - # build the key for this row - row_key = tuple(batch[o][row_idx].as_py() for o in self.key_offsets) - row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols))) - yield row_key, row - - _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) - data_batches = generate_data_batches(_batches) - - for k, g in groupby(data_batches, key=lambda x: x[0]): - chained = itertools.chain(g) - chained_values = map(lambda x: x[1], chained) - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, chained_values) - - yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) - - yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) - - def dump_stream(self, iterator, stream): - """ - Read through an iterator of (iterator of Row), serialize them to Arrow - RecordBatches, and write batches to stream. 
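The rewritten TransformWithStateInPandasSerializer.dump_stream above pairs each output DataFrame column with its target Arrow type, converts via PandasBatchTransformer.to_arrow, and then wraps the flat columns into a single struct column for the JVM. A condensed sketch of that path, as an editor's illustration under the assumption that the new conversion helpers from this patch are importable:

import pandas as pd
import pyarrow as pa
from pyspark.sql.conversion import PandasBatchTransformer, ArrowBatchTransformer

pdf = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
arrow_return_type = pa.struct([("id", pa.int64()), ("name", pa.string())])

# Pair output columns with their target Arrow types (by name, as when
# assign_cols_by_name is enabled), convert, then wrap into one struct column.
series_list = [(pdf[field.name], field.type) for field in arrow_return_type]
flat = PandasBatchTransformer.to_arrow(series_list, timezone="UTC", safecheck=True)
wrapped = ArrowBatchTransformer.wrap_struct(flat)
assert wrapped.num_columns == 1 and wrapped.num_rows == 2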
- """ - import pyarrow as pa - - def flatten_iterator(): - # iterator: iter[list[(iter[Row], pdf_type)]] - for packed in iterator: - iter_row_with_type = packed[0] - iter_row = iter_row_with_type[0] - pdf_type = iter_row_with_type[1] - - rows_as_dict = [] - for row in iter_row: - row_as_dict = row.asDict(True) - rows_as_dict.append(row_as_dict) - - pdf_schema = pa.schema(list(pdf_type)) - record_batch = pa.RecordBatch.from_pylist(rows_as_dict, schema=pdf_schema) - - yield (record_batch, pdf_type) - - return ArrowStreamUDFSerializer.dump_stream(self, flatten_iterator(), stream) - - -class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySparkRowSerializer): - """ - Serializer used by Python worker to evaluate UDF for - :meth:`pyspark.sql.GroupedData.transformWithStateInPySparkRowInitStateSerializer`. - Parameters - ---------- - Same as input parameters in TransformWithStateInPySparkRowSerializer. - """ - - def __init__(self, arrow_max_records_per_batch): - super().__init__(arrow_max_records_per_batch) - self.init_key_offsets = None - - def load_stream(self, stream): - import pyarrow as pa - from pyspark.sql.streaming.stateful_processor_util import ( - TransformWithStateInPandasFuncMode, - ) - from typing import Iterator, Any, Optional, Tuple - - def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optional[Any]]]: - """ - Deserialize ArrowRecordBatches and return a generator of Row. - The deserialization logic assumes that Arrow RecordBatches contain the data with the - ordering that data chunks for same grouping key will appear sequentially. - See `TransformWithStateInPySparkPythonInitialStateRunner` for arrow batch schema sent - from JVM. - This function flattens the columns of input rows and initial state rows and feed them - into the data generator. - """ - - def extract_rows( - cur_batch, col_name, key_offsets - ) -> Optional[Iterator[Tuple[Any, Any]]]: - data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) - - # Check if the entire column is null - if data_column.null_count == len(data_column): - return None - - data_field_names = [ - data_column.type[i].name for i in range(data_column.type.num_fields) - ] - data_field_arrays = [ - data_column.field(i) for i in range(data_column.type.num_fields) - ] - - DataRow = Row(*data_field_names) - - table = pa.Table.from_arrays(data_field_arrays, names=data_field_names) - - if table.num_rows == 0: - return None - - def row_iterator(): - for row_idx in range(table.num_rows): - key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets) - row = DataRow( - *(table.column(i)[row_idx].as_py() for i in range(table.num_columns)) - ) - yield (key, row) - - return row_iterator() - - """ - The arrow batch is written in the schema: - schema: StructType = new StructType() - .add("inputData", dataSchema) - .add("initState", initStateSchema) - We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python - data generator. Each batch will have either init_data or input_data, not mix. 
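The inputData/initState layout described just above (and checked via has_data/has_init in the new generate_data_batches) can be pictured with a tiny pyarrow sketch; the inner field names below are illustrative, only the two top-level struct columns mirror the actual batch schema.

import pyarrow as pa

data_schema = pa.struct([("id", pa.int64())])
init_schema = pa.struct([("count", pa.int64())])
# A batch carries either input rows or initial-state rows; the other struct
# column is entirely null.
batch = pa.RecordBatch.from_arrays(
    [
        pa.array([{"id": 1}, {"id": 2}], type=data_schema),
        pa.array([None, None], type=init_schema),
    ],
    ["inputData", "initState"],
)
assert batch.column(1).null_count == batch.num_rows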
- """ - for batch in batches: - # Detect which column has data - each batch contains only one type - input_result = extract_rows(batch, "inputData", self.key_offsets) - init_result = extract_rows(batch, "initState", self.init_key_offsets) - - assert not (input_result is not None and init_result is not None) - - if input_result is not None: - for key, input_data_row in input_result: - yield (key, input_data_row, None) - elif init_result is not None: - for key, init_state_row in init_result: - yield (key, None, init_state_row) - - _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) - data_batches = generate_data_batches(_batches) - - for k, g in groupby(data_batches, key=lambda x: x[0]): - input_rows = [] - init_rows = [] - - for batch_key, input_row, init_row in g: - if input_row is not None: - input_rows.append(input_row) - if init_row is not None: - init_rows.append(init_row) - - total_len = len(input_rows) + len(init_rows) - if total_len >= self.arrow_max_records_per_batch: - ret_tuple = (iter(input_rows), iter(init_rows)) - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) - input_rows = [] - init_rows = [] - - if input_rows or init_rows: - ret_tuple = (iter(input_rows), iter(init_rows)) - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) - - yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) - - yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index c3fa1fd193042..99e6a4f403514 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -24,6 +24,8 @@ LocalDataToArrowConversion, ArrowTimestampConversion, ArrowBatchTransformer, + PandasBatchTransformer, + PandasSeriesToArrowConversion, ) from pyspark.sql.types import ( ArrayType, @@ -143,6 +145,500 @@ def test_wrap_struct_empty_batch(self): self.assertEqual(wrapped.num_rows, 0) self.assertEqual(wrapped.num_columns, 1) + def test_concat_batches_basic(self): + """Test concatenating multiple batches vertically.""" + import pyarrow as pa + + batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["x"]) + batch2 = pa.RecordBatch.from_arrays([pa.array([3, 4])], ["x"]) + + result = ArrowBatchTransformer.concat_batches([batch1, batch2]) + + self.assertEqual(result.num_rows, 4) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3, 4]) + + def test_concat_batches_single(self): + """Test concatenating a single batch returns the same batch.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["x"]) + + result = ArrowBatchTransformer.concat_batches([batch]) + + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_concat_batches_empty_batches(self): + """Test concatenating empty batches.""" + import pyarrow as pa + + schema = pa.schema([("x", pa.int64())]) + batch1 = pa.RecordBatch.from_arrays([pa.array([], type=pa.int64())], schema=schema) + batch2 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["x"]) + + result = ArrowBatchTransformer.concat_batches([batch1, batch2]) + + self.assertEqual(result.num_rows, 2) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + + def test_zip_batches_record_batches(self): + """Test zipping multiple RecordBatches horizontally.""" + import pyarrow as pa + + batch1 = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["a"]) + batch2 = pa.RecordBatch.from_arrays([pa.array(["x", 
"y"])], ["b"]) + + result = ArrowBatchTransformer.zip_batches([batch1, batch2]) + + self.assertEqual(result.num_columns, 2) + self.assertEqual(result.num_rows, 2) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + self.assertEqual(result.column(1).to_pylist(), ["x", "y"]) + + def test_zip_batches_arrays(self): + """Test zipping Arrow arrays directly.""" + import pyarrow as pa + + arr1 = pa.array([1, 2, 3]) + arr2 = pa.array(["a", "b", "c"]) + + result = ArrowBatchTransformer.zip_batches([arr1, arr2]) + + self.assertEqual(result.num_columns, 2) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + self.assertEqual(result.column(1).to_pylist(), ["a", "b", "c"]) + + def test_zip_batches_with_type_casting(self): + """Test zipping with type casting.""" + import pyarrow as pa + + arr = pa.array([1, 2, 3], type=pa.int32()) + result = ArrowBatchTransformer.zip_batches([(arr, pa.int64())]) + + self.assertEqual(result.column(0).type, pa.int64()) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_zip_batches_empty_raises(self): + """Test that zipping empty list raises error.""" + with self.assertRaises(PySparkValueError): + ArrowBatchTransformer.zip_batches([]) + + def test_zip_batches_single_batch(self): + """Test zipping single batch returns it unchanged.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], ["x"]) + result = ArrowBatchTransformer.zip_batches([batch]) + + self.assertIs(result, batch) + + def test_to_pandas_basic(self): + """Test basic Arrow to pandas conversion.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], + names=["x", "y"], + ) + + result = ArrowBatchTransformer.to_pandas(batch, timezone="UTC") + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].tolist(), [1, 2, 3]) + self.assertEqual(result[1].tolist(), ["a", "b", "c"]) + + def test_to_pandas_empty_table(self): + """Test converting empty table returns NoValue series.""" + import pyarrow as pa + + table = pa.Table.from_arrays([], names=[]) + + result = ArrowBatchTransformer.to_pandas(table, timezone="UTC") + + # Empty table with rows should return a series with _NoValue + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 0) + + def test_to_pandas_with_nulls(self): + """Test conversion with null values.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, None, 3]), pa.array([None, "b", None])], + names=["x", "y"], + ) + + result = ArrowBatchTransformer.to_pandas(batch, timezone="UTC") + + self.assertEqual(len(result), 2) + self.assertTrue(result[0].isna().tolist() == [False, True, False]) + self.assertTrue(result[1].isna().tolist() == [True, False, True]) + + def test_to_pandas_with_schema_and_udt(self): + """Test conversion with Spark schema containing UDT.""" + import pyarrow as pa + + # Create Arrow batch with raw UDT data (list representation) + batch = pa.RecordBatch.from_arrays( + [pa.array([[1.0, 2.0], [3.0, 4.0]], type=pa.list_(pa.float64()))], + names=["point"], + ) + + schema = StructType([StructField("point", ExamplePointUDT())]) + + result = ArrowBatchTransformer.to_pandas(batch, timezone="UTC", schema=schema) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], ExamplePoint(1.0, 2.0)) + self.assertEqual(result[0][1], ExamplePoint(3.0, 4.0)) + + def test_flatten_and_wrap_roundtrip(self): + """Test that flatten -> wrap produces equivalent result.""" + import pyarrow as pa + + struct_array = 
pa.StructArray.from_arrays( + [pa.array([1, 2]), pa.array(["a", "b"])], + names=["x", "y"], + ) + original = pa.RecordBatch.from_arrays([struct_array], ["_0"]) + + flattened = ArrowBatchTransformer.flatten_struct(original) + rewrapped = ArrowBatchTransformer.wrap_struct(flattened) + + self.assertEqual(rewrapped.num_columns, 1) + self.assertEqual(rewrapped.column(0).to_pylist(), original.column(0).to_pylist()) + + def test_enforce_schema_reorder(self): + """Test reordering columns to match target schema.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + # Batch with columns in different order than target + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["a", "b"])], + names=["b", "a"], + ) + target = StructType( + [ + StructField("a", StringType()), + StructField("b", LongType()), + ] + ) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.schema.names, ["a", "b"]) + self.assertEqual(result.column(0).to_pylist(), ["a", "b"]) + self.assertEqual(result.column(1).to_pylist(), [1, 2]) + + def test_enforce_schema_coerce_types(self): + """Test coercing column types to match target schema.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + # Batch with int32 that should be coerced to int64 (LongType) + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2], type=pa.int32())], + names=["x"], + ) + target = StructType([StructField("x", LongType())]) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.column(0).type, pa.int64()) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + + def test_enforce_schema_reorder_and_coerce(self): + """Test both reordering and type coercion together.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2], type=pa.int32()), pa.array(["x", "y"])], + names=["b", "a"], + ) + target = StructType( + [ + StructField("a", StringType()), + StructField("b", LongType()), + ] + ) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.schema.names, ["a", "b"]) + self.assertEqual(result.column(0).to_pylist(), ["x", "y"]) + self.assertEqual(result.column(1).type, pa.int64()) + self.assertEqual(result.column(1).to_pylist(), [1, 2]) + + def test_enforce_schema_empty_batch(self): + """Test that empty batch is returned as-is.""" + import pyarrow as pa + from pyspark.sql.types import LongType + + batch = pa.RecordBatch.from_arrays([], names=[]) + target = StructType([StructField("x", LongType())]) + + result = ArrowBatchTransformer.enforce_schema(batch, target) + + self.assertEqual(result.num_columns, 0) + + def test_enforce_schema_with_large_var_types(self): + """Test using prefer_large_var_types option.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["a", "b"])], + names=["b", "a"], + ) + target = StructType( + [ + StructField("a", StringType()), + StructField("b", IntegerType()), + ] + ) + + result = ArrowBatchTransformer.enforce_schema(batch, target, prefer_large_var_types=True) + + self.assertEqual(result.schema.names, ["a", "b"]) + # With prefer_large_var_types=True, string should be large_string + self.assertEqual(result.column(0).type, pa.large_string()) + self.assertEqual(result.column(0).to_pylist(), ["a", "b"]) + self.assertEqual(result.column(1).to_pylist(), [1, 2]) + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class 
PandasBatchTransformerTests(unittest.TestCase): + def test_to_arrow_single_series(self): + """Test converting a single pandas Series to Arrow.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, 2, 3]) + result = PandasBatchTransformer.to_arrow( + (series, pa.int64(), IntegerType()), timezone="UTC" + ) + + self.assertEqual(result.num_columns, 1) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_to_arrow_empty_series(self): + """Test converting empty pandas Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([], dtype="int64") + result = PandasBatchTransformer.to_arrow( + (series, pa.int64(), IntegerType()), timezone="UTC" + ) + + self.assertEqual(result.num_rows, 0) + + def test_to_arrow_with_nulls(self): + """Test converting Series with null values.""" + import pandas as pd + import pyarrow as pa + import numpy as np + from pyspark.sql.types import DoubleType + + series = pd.Series([1.0, np.nan, 3.0]) + result = PandasBatchTransformer.to_arrow( + (series, pa.float64(), DoubleType()), timezone="UTC" + ) + + self.assertEqual(result.column(0).to_pylist(), [1.0, None, 3.0]) + + def test_to_arrow_multiple_series(self): + """Test converting multiple series at once.""" + import pandas as pd + import pyarrow as pa + + series1 = pd.Series([1, 2]) + series2 = pd.Series(["a", "b"]) + + result = PandasBatchTransformer.to_arrow( + [(series1, pa.int64(), None), (series2, pa.string(), None)], timezone="UTC" + ) + + self.assertEqual(result.num_columns, 2) + self.assertEqual(result.column(0).to_pylist(), [1, 2]) + self.assertEqual(result.column(1).to_pylist(), ["a", "b"]) + + def test_to_arrow_struct_mode(self): + """Test converting DataFrame to struct array.""" + import pandas as pd + import pyarrow as pa + + df = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]}) + arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + spark_type = StructType([StructField("x", IntegerType()), StructField("y", StringType())]) + + result = PandasBatchTransformer.to_arrow( + (df, arrow_type, spark_type), timezone="UTC", as_struct=True + ) + + self.assertEqual(result.num_columns, 1) + struct_col = result.column(0) + self.assertEqual(struct_col.field(0).to_pylist(), [1, 2]) + self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b"]) + + def test_to_arrow_struct_empty_dataframe(self): + """Test converting empty DataFrame in struct mode.""" + import pandas as pd + import pyarrow as pa + + df = pd.DataFrame({"x": pd.Series([], dtype="int64"), "y": pd.Series([], dtype="str")}) + arrow_type = pa.struct([("x", pa.int64()), ("y", pa.string())]) + spark_type = StructType([StructField("x", IntegerType()), StructField("y", StringType())]) + + result = PandasBatchTransformer.to_arrow( + (df, arrow_type, spark_type), timezone="UTC", as_struct=True + ) + + self.assertEqual(result.num_rows, 0) + + def test_to_arrow_categorical_series(self): + """Test converting categorical pandas Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Categorical(["a", "b", "a", "c"]) + result = PandasBatchTransformer.to_arrow( + (pd.Series(series), pa.string(), StringType()), timezone="UTC" + ) + + self.assertEqual(result.column(0).to_pylist(), ["a", "b", "a", "c"]) + + def test_to_arrow_requires_dataframe_for_struct(self): + """Test that struct mode requires DataFrame input.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, 2]) + arrow_type = pa.struct([("x", pa.int64())]) + + with self.assertRaises(PySparkValueError): + 
PandasBatchTransformer.to_arrow( + (series, arrow_type, None), timezone="UTC", as_struct=True + ) + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class PandasSeriesToArrowConversionTests(unittest.TestCase): + def test_create_array_basic(self): + """Test basic array creation from pandas Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, 2, 3]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(result.to_pylist(), [1, 2, 3]) + + def test_create_array_empty(self): + """Test creating array from empty Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([], dtype="int64") + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(len(result), 0) + + def test_create_array_with_nulls(self): + """Test creating array with null values.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([1, None, 3], dtype="Int64") # nullable integer + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(result.to_pylist(), [1, None, 3]) + + def test_create_array_categorical(self): + """Test creating array from categorical Series.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series(pd.Categorical(["x", "y", "x"])) + result = PandasSeriesToArrowConversion.create_array( + series, pa.string(), "UTC", spark_type=StringType() + ) + + self.assertEqual(result.to_pylist(), ["x", "y", "x"]) + + def test_create_array_with_udt(self): + """Test creating array with UDT.""" + import pandas as pd + import pyarrow as pa + + points = [ExamplePoint(1.0, 2.0), ExamplePoint(3.0, 4.0)] + series = pd.Series(points) + + result = PandasSeriesToArrowConversion.create_array( + series, pa.list_(pa.float64()), "UTC", spark_type=ExamplePointUDT() + ) + + self.assertEqual(result.to_pylist(), [[1.0, 2.0], [3.0, 4.0]]) + + def test_create_array_binary(self): + """Test creating array with binary data.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([b"hello", b"world"]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.binary(), "UTC", spark_type=BinaryType() + ) + + self.assertEqual(result.to_pylist(), [b"hello", b"world"]) + + def test_create_array_all_nulls(self): + """Test creating array with all null values.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([None, None, None]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.int64(), "UTC", spark_type=IntegerType() + ) + + self.assertEqual(result.to_pylist(), [None, None, None]) + + def test_create_array_nested_list(self): + """Test creating array with nested list type.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([[1, 2], [3, 4, 5], []]) + result = PandasSeriesToArrowConversion.create_array( + series, pa.list_(pa.int64()), "UTC", spark_type=ArrayType(IntegerType()) + ) + + self.assertEqual(result.to_pylist(), [[1, 2], [3, 4, 5], []]) + + def test_create_array_map_type(self): + """Test creating array with map type.""" + import pandas as pd + import pyarrow as pa + + series = pd.Series([{"a": 1}, {"b": 2, "c": 3}]) + result = PandasSeriesToArrowConversion.create_array( + series, + pa.map_(pa.string(), pa.int64()), + "UTC", + spark_type=MapType(StringType(), IntegerType()), + ) + + # Maps are returned as list of tuples + 
self.assertEqual(result.to_pylist(), [[("a", 1)], [("b", 2), ("c", 3)]]) + @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) class ConversionTests(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 03bc1366e875d..34fee1aaa0ec6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -50,28 +50,17 @@ from pyspark.sql.conversion import ( LocalDataToArrowConversion, ArrowTableToRowsConversion, + ArrowArrayToPandasConversion, ArrowBatchTransformer, + PandasBatchTransformer, ) from pyspark.sql.functions import SkipRestOfInputTableException from pyspark.sql.pandas.serializers import ( - ArrowStreamPandasUDFSerializer, - ArrowStreamPandasUDTFSerializer, - ArrowStreamGroupUDFSerializer, - GroupPandasUDFSerializer, - CogroupArrowUDFSerializer, - CogroupPandasUDFSerializer, - ArrowStreamUDFSerializer, + ArrowStreamGroupSerializer, ApplyInPandasWithStateSerializer, TransformWithStateInPandasSerializer, TransformWithStateInPandasInitStateSerializer, - TransformWithStateInPySparkRowSerializer, - TransformWithStateInPySparkRowInitStateSerializer, - ArrowStreamArrowUDFSerializer, - ArrowStreamAggPandasUDFSerializer, - ArrowStreamAggArrowUDFSerializer, ArrowBatchUDFSerializer, - ArrowStreamUDTFSerializer, - ArrowStreamArrowUDTFSerializer, ) from pyspark.sql.pandas.types import to_arrow_type, TimestampType from pyspark.sql.types import ( @@ -228,84 +217,111 @@ def wrap_udf(f, args_offsets, kwargs_offsets, return_type): return args_kwargs_offsets, lambda *a: func(*a) +# ============================================================ +# Common verification functions for UDF wrappers +# ============================================================ + + +def verify_result_length(result, expected_length, udf_type): + """Verify that the result length matches the expected length.""" + if len(result) != expected_length: + raise PySparkRuntimeError( + errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", + messageParameters={ + "udf_type": udf_type, + "expected": str(expected_length), + "actual": str(len(result)), + }, + ) + return result + + +def verify_result_type(result, expected_type_name): + """Verify that the result has __len__ attribute (is array-like).""" + if not hasattr(result, "__len__"): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": expected_type_name, + "actual": type(result).__name__, + }, + ) + return result + + +def verify_is_iterable(result, expected_type_name): + """Verify that the result is an iterator or iterable.""" + if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": expected_type_name, + "actual": type(result).__name__, + }, + ) + return result + + +def verify_element_type(elem, expected_class, expected_type_name): + """Verify that an element has the expected type.""" + if not isinstance(elem, expected_class): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": expected_type_name, + "actual": "iterator of {}".format(type(elem).__name__), + }, + ) + return elem + + def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + """Wrap a scalar pandas UDF. + + The wrapper only validates and calls the user function. + Output conversion (pandas → Arrow) is done in the mapper. 
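The refactored scalar wrappers now return the raw result paired with its Arrow return type and leave conversion to the mapper, as the docstring above states. A hypothetical stand-in wrapper (not PySpark code) showing that contract together with the checks the shared verify helpers perform:

import pandas as pd
import pyarrow as pa

def wrapped_double(s: pd.Series):
    result = s * 2
    # Same checks as verify_result_type / verify_result_length:
    # array-like result whose length matches the input.
    assert hasattr(result, "__len__")
    assert len(result) == len(s)
    return (result, pa.int64())

out, arrow_type = wrapped_double(pd.Series([1, 2, 3]))
assert out.tolist() == [2, 4, 6] and arrow_type == pa.int64()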
+ + Returns: + (arg_offsets, wrapped_func) where wrapped_func returns (result, arrow_return_type) + """ func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) + pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - def verify_result_type(result): - if not hasattr(result, "__len__"): - pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": pd_type, - "actual": type(result).__name__, - }, - ) - return result + def wrapped(*args): + result = func(*args) + verify_result_type(result, pd_type) + verify_result_length(result, len(args[0]), "pandas_udf") + return (result, arrow_return_type) - def verify_result_length(result, length): - if len(result) != length: - raise PySparkRuntimeError( - errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", - messageParameters={ - "udf_type": "pandas_udf", - "expected": str(length), - "actual": str(len(result)), - }, - ) - return result - - return ( - args_kwargs_offsets, - lambda *a: ( - verify_result_length(verify_result_type(func(*a)), len(a[0])), - arrow_return_type, - ), - ) + return (args_kwargs_offsets, wrapped) def wrap_scalar_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + """Wrap a scalar arrow UDF. + + The wrapper only validates and calls the user function. + Output conversion is done in the mapper. + + Returns: + (arg_offsets, wrapped_func) where wrapped_func returns (result, arrow_return_type) + """ func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) - def verify_result_type(result): - if not hasattr(result, "__len__"): - pd_type = "pyarrow.Array" - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": pd_type, - "actual": type(result).__name__, - }, - ) - return result + def wrapped(*args): + result = func(*args) + verify_result_type(result, "pyarrow.Array") + verify_result_length(result, len(args[0]), "arrow_udf") + return (result, arrow_return_type) - def verify_result_length(result, length): - if len(result) != length: - raise PySparkRuntimeError( - errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", - messageParameters={ - "udf_type": "arrow_udf", - "expected": str(length), - "actual": str(len(result)), - }, - ) - return result - - return ( - args_kwargs_offsets, - lambda *a: ( - verify_result_length(verify_result_type(func(*a)), len(a[0])), - arrow_return_type, - ), - ) + return (args_kwargs_offsets, wrapped) def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): @@ -450,43 +466,36 @@ def verify_result_length(result, length): def wrap_pandas_batch_iter_udf(f, return_type, runner_conf): + """Wrap a pandas batch iter UDF. + + The wrapper only validates and calls the user function. + Output conversion (pandas → Arrow) is done in func. 
+ + Returns: + wrapped_func that yields (result, arrow_return_type) tuples + """ + import pandas as pd + arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - - def verify_result(result): - if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of {}".format(iter_type_label), - "actual": type(result).__name__, - }, - ) - return result + expected_class = pd.DataFrame if type(return_type) == StructType else pd.Series def verify_element(elem): - import pandas as pd - - if not isinstance(elem, pd.DataFrame if type(return_type) == StructType else pd.Series): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of {}".format(iter_type_label), - "actual": "iterator of {}".format(type(elem).__name__), - }, - ) - + verify_element_type(elem, expected_class, "iterator of {}".format(iter_type_label)) verify_pandas_result( elem, return_type, assign_cols_by_name=True, truncate_return_schema=True ) - return elem - return lambda *iterator: map( - lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) - ) + def wrapped(*iterator): + result = f(*iterator) + verify_is_iterable(result, "iterator of {}".format(iter_type_label)) + # Yield (result, arrow_type) pairs - output conversion done in func + return ((verify_element(elem), arrow_return_type) for elem in result) + + return wrapped def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema): @@ -506,7 +515,7 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu if not result.empty or len(result.columns) != 0: # if any column name of the result is a string # the column names of the result have to match the return type - # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer + # see create_array in pyspark.sql.pandas.serializers.ArrowStreamUDFSerializer field_names = set([field.name for field in return_type.fields]) # only the first len(field_names) result columns are considered # when truncating the return schema @@ -550,73 +559,48 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu def wrap_arrow_array_iter_udf(f, return_type, runner_conf): - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - - def verify_result(result): - if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.Array", - "actual": type(result).__name__, - }, - ) - return result - - def verify_element(elem): - import pyarrow as pa - - if not isinstance(elem, pa.Array): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.Array", - "actual": "iterator of {}".format(type(elem).__name__), - }, - ) + """Wrap an arrow array iter UDF. - return elem - - return lambda *iterator: map( - lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) - ) + The wrapper only validates and calls the user function. + Output conversion is done in func. 
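For the iterator-based Arrow UDFs above, the wrapper only validates each yielded pyarrow.Array and pairs it with the declared Arrow return type; building RecordBatches happens later in the mapper. A small illustrative sketch, where plus_one is a hypothetical user function, not part of the patch:

import pyarrow as pa
import pyarrow.compute as pc

def plus_one(it):
    # Hypothetical user function for an iterator-of-pyarrow.Array UDF.
    for arr in it:
        yield pc.add(arr, 1)

arrow_return_type = pa.int64()
# Mirror of the wrapper: validate, then pair each element with the return type.
wrapped = ((elem, arrow_return_type) for elem in plus_one(iter([pa.array([1, 2, 3])])))
arr, t = next(wrapped)
assert arr.to_pylist() == [2, 3, 4] and t == pa.int64()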
+ Returns: + wrapped_func that yields (result, arrow_return_type) tuples + """ + import pyarrow as pa -def wrap_arrow_batch_iter_udf(f, return_type, runner_conf): arrow_return_type = to_arrow_type( return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types ) + iter_type = "iterator of pyarrow.Array" - def verify_result(result): - if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.RecordBatch", - "actual": type(result).__name__, - }, - ) - return result + def wrapped(*iterator): + result = f(*iterator) + verify_is_iterable(result, iter_type) + # Yield (result, arrow_type) pairs - output conversion done in func + return ( + (verify_element_type(elem, pa.Array, iter_type), arrow_return_type) for elem in result + ) - def verify_element(elem): - import pyarrow as pa + return wrapped - if not isinstance(elem, pa.RecordBatch): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "iterator of pyarrow.RecordBatch", - "actual": "iterator of {}".format(type(elem).__name__), - }, - ) - return elem +def wrap_arrow_batch_iter_udf(f, return_type, runner_conf): + import pyarrow as pa - return lambda *iterator: map( - lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator))) - ) + iter_type = "iterator of pyarrow.RecordBatch" + + # For mapInArrow: batches are already flattened in func before calling wrapper + # User function receives flattened batches and returns flattened batches + # We just need to wrap results back into struct for serialization + def wrapper(batch_iter): + result = f(batch_iter) + verify_is_iterable(result, iter_type) + for res in result: + verify_element_type(res, pa.RecordBatch, iter_type) + yield ArrowBatchTransformer.wrap_struct(res) + + return wrapper def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): @@ -639,15 +623,24 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) - return result.to_batches() + for batch in result.to_batches(): + batch = ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) + yield ArrowBatchTransformer.wrap_struct(batch) - return lambda kl, vl, kr, vr: ( - wrapped(kl, vl, kr, vr), - to_arrow_type(return_type, timezone="UTC"), - ) + return lambda kl, vl, kr, vr: wrapped(kl, vl, kr, vr) def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): + from pyspark.sql.conversion import PandasBatchTransformer + + arrow_return_type = to_arrow_type( + return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types + ) + def wrapped(left_key_series, left_value_series, right_key_series, right_value_series): import pandas as pd @@ -664,12 +657,19 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False ) - return result + # Convert pandas DataFrame to Arrow RecordBatch and yield (consistent with Arrow cogrouped) + batch = PandasBatchTransformer.to_arrow( + [(result, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + struct_in_pandas="dict", + 
assign_cols_by_name=runner_conf.assign_cols_by_name, + arrow_cast=True, + ) + yield batch - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), arrow_return_type)] + return wrapped def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types): @@ -785,12 +785,16 @@ def wrapped(key_batch, value_batches): verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) - yield from result.to_batches() + # Enforce schema (reorder + coerce) and wrap each batch into struct + for batch in result.to_batches(): + batch = ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) + yield ArrowBatchTransformer.wrap_struct(batch) - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - return lambda k, v: (wrapped(k, v), arrow_return_type) + return lambda k, v: wrapped(k, v) def wrap_grouped_map_arrow_iter_udf(f, return_type, argspec, runner_conf): @@ -810,30 +814,25 @@ def wrapped(key_batch, value_batches): key = tuple(c[0] for c in key_batch.columns) result = f(key, value_batches) - def verify_element(batch): + for batch in result: verify_arrow_batch(batch, runner_conf.assign_cols_by_name, expected_cols_and_types) - return batch - - yield from map(verify_element, result) + batch = ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) + yield ArrowBatchTransformer.wrap_struct(batch) - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - return lambda k, v: (wrapped(k, v), arrow_return_type) + return lambda k, v: wrapped(k, v) def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): - def wrapped(key_series, value_batches): + def wrapped(key_series, value_dfs): import pandas as pd - # Convert value_batches (Iterator[list[pd.Series]]) to a single DataFrame - # Each value_series is a list of Series (one per column) for one batch - # Concatenate Series within each batch (axis=1), then concatenate batches (axis=0) - value_dataframes = [] - for value_series in value_batches: - value_dataframes.append(pd.concat(value_series, axis=1)) - - value_df = pd.concat(value_dataframes, axis=0) if value_dataframes else pd.DataFrame() + # Concat all DataFrames in the group vertically + dfs = list(value_dfs) + value_df = pd.concat(dfs, axis=0) if dfs else pd.DataFrame() if len(argspec.args) == 1: result = f(value_df) @@ -853,29 +852,57 @@ def wrapped(key_series, value_batches): ) def flatten_wrapper(k, v): - # Return Iterator[[(df, arrow_type)]] directly + # Convert pandas DataFrame to Arrow RecordBatch and wrap as struct for JVM + import pyarrow as pa + from pyspark.sql.conversion import PandasBatchTransformer, ArrowBatchTransformer + for df in wrapped(k, v): - yield [(df, arrow_return_type)] + # Handle empty DataFrame (no columns) + if len(df.columns) == 0: + # Create empty struct array with correct schema + # Use generic field names (_0, _1, ...) 
to match wrap_struct behavior + struct_fields = [ + pa.field(f"_{i}", field.type) for i, field in enumerate(arrow_return_type) + ] + struct_arr = pa.array([{}] * len(df), pa.struct(struct_fields)) + batch = pa.RecordBatch.from_arrays([struct_arr], ["_0"]) + yield batch + continue + + # Build list of (Series, arrow_type) tuples + # If assign_cols_by_name and DataFrame has named columns, match by name + # Otherwise, match by position + if runner_conf.assign_cols_by_name and any( + isinstance(name, str) for name in df.columns + ): + series_list = [(df[field.name], field.type) for field in arrow_return_type] + else: + series_list = [ + (df[df.columns[i]].rename(field.name), field.type) + for i, field in enumerate(arrow_return_type) + ] + batch = PandasBatchTransformer.to_arrow( + series_list, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ) + # Wrap columns as struct for JVM compatibility + yield ArrowBatchTransformer.wrap_struct(batch) return flatten_wrapper def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf): - def wrapped(key_series, value_batches): - import pandas as pd - - # value_batches is an Iterator[list[pd.Series]] (one list per batch) - # Convert each list of Series into a DataFrame - def dataframe_iter(): - for value_series in value_batches: - yield pd.concat(value_series, axis=1) - + def wrapped(key_series, value_dfs): + # value_dfs is already Iterator[pd.DataFrame] from mapper if len(argspec.args) == 1: - result = f(dataframe_iter()) + result = f(value_dfs) elif len(argspec.args) == 2: # Extract key from pandas Series, preserving numpy types key = tuple(s.iloc[0] for s in key_series) - result = f(key, dataframe_iter()) + result = f(key, value_dfs) def verify_element(df): verify_pandas_result( @@ -890,9 +917,44 @@ def verify_element(df): ) def flatten_wrapper(k, v): - # Return Iterator[[(df, arrow_type)]] directly + # Convert pandas DataFrames to Arrow RecordBatches and wrap as struct for JVM + import pyarrow as pa + from pyspark.sql.conversion import PandasBatchTransformer, ArrowBatchTransformer + for df in wrapped(k, v): - yield [(df, arrow_return_type)] + # Handle empty DataFrame (no columns) + if len(df.columns) == 0: + # Create empty struct array with correct schema + # Use generic field names (_0, _1, ...) 
to match wrap_struct behavior + struct_fields = [ + pa.field(f"_{i}", field.type) for i, field in enumerate(arrow_return_type) + ] + struct_arr = pa.array([{}] * len(df), pa.struct(struct_fields)) + batch = pa.RecordBatch.from_arrays([struct_arr], ["_0"]) + yield batch + continue + + # Build list of (Series, arrow_type) tuples + # If assign_cols_by_name and DataFrame has named columns, match by name + # Otherwise, match by position + if runner_conf.assign_cols_by_name and any( + isinstance(name, str) for name in df.columns + ): + series_list = [(df[field.name], field.type) for field in arrow_return_type] + else: + series_list = [ + (df[df.columns[i]].rename(field.name), field.type) + for i, field in enumerate(arrow_return_type) + ] + batch = PandasBatchTransformer.to_arrow( + series_list, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ) + # Wrap columns as struct for JVM compatibility + yield ArrowBatchTransformer.wrap_struct(batch) return flatten_wrapper @@ -1068,6 +1130,8 @@ def verify_element(result): def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + from pyspark.sql.conversion import PandasBatchTransformer + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( @@ -1078,12 +1142,18 @@ def wrapped(*series): import pandas as pd result = func(*series) - return pd.Series([result]) + result_series = pd.Series([result]) + + # Convert to Arrow RecordBatch in wrapper + return PandasBatchTransformer.to_arrow( + [(result_series, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ) - return ( - args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), - ) + return (args_kwargs_offsets, wrapped) def wrap_grouped_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): @@ -1097,11 +1167,15 @@ def wrapped(*series): import pyarrow as pa result = func(*series) - return pa.array([result]) + array = pa.array([result]) + # Return RecordBatch directly instead of (array, type) tuple + return ArrowBatchTransformer.zip_batches( + [(array, arrow_return_type)], + ) return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + wrapped, ) @@ -1117,15 +1191,21 @@ def wrapped(batch_iter): # batch_iter: Iterator[pa.Array] (single) or Iterator[Tuple[pa.Array, ...]] (multiple) result = func(batch_iter) - return pa.array([result]) + array = pa.array([result]) + # Return RecordBatch directly instead of (array, type) tuple + return ArrowBatchTransformer.zip_batches( + [(array, arrow_return_type)], + ) return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + wrapped, ) def wrap_grouped_agg_pandas_iter_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + from pyspark.sql.conversion import PandasBatchTransformer + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) arrow_return_type = to_arrow_type( @@ -1139,12 +1219,18 @@ def wrapped(series_iter): # Iterator[Tuple[pd.Series, ...]] (multiple columns) # This has already been adapted by the mapper function in read_udfs result = func(series_iter) - return pd.Series([result]) + result_series = pd.Series([result]) + + # Convert to Arrow RecordBatch in wrapper + return PandasBatchTransformer.to_arrow( + 
[(result_series, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ) - return ( - args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), - ) + return (args_kwargs_offsets, wrapped) def wrap_window_agg_pandas_udf( @@ -1208,7 +1294,13 @@ def wrapped(*series): return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: PandasBatchTransformer.to_arrow( + [(wrapped(*a), arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ), ) @@ -1229,7 +1321,9 @@ def wrapped(*series): return ( args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: ArrowBatchTransformer.zip_batches( + [(wrapped(*a), arrow_return_type)], + ), ) @@ -1273,7 +1367,13 @@ def wrapped(begin_index, end_index, *series): return ( args_offsets[:2] + args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: PandasBatchTransformer.to_arrow( + [(wrapped(*a), arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + ), ) @@ -1302,7 +1402,9 @@ def wrapped(begin_index, end_index, *series): return ( args_offsets[:2] + args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), + lambda *a: ArrowBatchTransformer.zip_batches( + [(wrapped(*a), arrow_return_type)], + ), ) @@ -1533,21 +1635,15 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf): if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF: input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if runner_conf.use_legacy_pandas_udtf_conversion: - # NOTE: if timezone is set here, that implies respectSessionTimeZone is True - ser = ArrowStreamPandasUDTFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - input_type=input_type, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - ) + # UDTF uses ArrowStreamGroupSerializer + ser = ArrowStreamGroupSerializer() else: - ser = ArrowStreamUDTFSerializer() + ser = ArrowStreamGroupSerializer() elif eval_type == PythonEvalType.SQL_ARROW_UDTF: - # Read the table argument offsets + # Read the table argument offsets (used by mapper to flatten struct columns) num_table_arg_offsets = read_int(infile) table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)] - # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support - ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets) + ser = ArrowStreamGroupSerializer() else: # Each row is a group so do not batch but send one by one. 
ser = BatchedSerializer(CPickleSerializer(), 1) @@ -2291,16 +2387,46 @@ def evaluate(*args: pd.Series, num_rows=1): cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None def mapper(_, it): + def convert_output(output_iter): + """Convert (pandas.DataFrame, arrow_type, spark_type) to Arrow RecordBatch.""" + for df, arrow_type, spark_type in output_iter: + yield PandasBatchTransformer.to_arrow( + (df, arrow_type, spark_type), + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + as_struct=True, + assign_cols_by_name=False, + arrow_cast=True, + ignore_unexpected_complex_type_values=True, + error_class="UDTF_ARROW_TYPE_CAST_ERROR", + ) + try: for a in it: + # Convert Arrow columns to pandas Series for legacy UDTF + # Use struct_in_pandas="row" to match master's behavior for table arguments + series = [ + ArrowArrayToPandasConversion.convert_legacy( + a.column(i), + f.dataType, + timezone=runner_conf.timezone, + struct_in_pandas="row", + ndarray_as_list=True, + ) + for i, f in enumerate(input_type) + ] # The eval function yields an iterator. Each element produced by this # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type). - yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=len(a[0])) + num_rows = a.num_rows if not series else len(series[0]) + yield from convert_output( + eval(*[series[o] for o in args_kwargs_offsets], num_rows=num_rows) + ) if terminate is not None: - yield from terminate() + yield from convert_output(terminate()) except SkipRestOfInputTableException: if terminate is not None: - yield from terminate() + yield from convert_output(terminate()) finally: if cleanup is not None: cleanup() @@ -2389,12 +2515,10 @@ def check_return_value(res): yield row def convert_to_arrow(data: Iterable): + """ + Convert Python data to coerced Arrow batches. + """ data = list(check_return_value(data)) - if len(data) == 0: - # Return one empty RecordBatch to match the left side of the lateral join - return [ - pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type))) - ] def raise_conversion_error(original_exception): raise PySparkRuntimeError( @@ -2429,18 +2553,28 @@ def raise_conversion_error(original_exception): except Exception as e: raise_conversion_error(e) - return verify_result(table).to_batches() + # Enforce schema (reorder + coerce) for each batch + table = verify_result(table) + batches = table.to_batches() + if len(batches) == 0: + # Empty table - create empty batch for lateral join semantics + batches = [pa.RecordBatch.from_pylist([], schema=table.schema)] + for batch in batches: + yield ArrowBatchTransformer.enforce_schema( + batch, + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) def evaluate(*args: list, num_rows=1): if len(args) == 0: for _ in range(num_rows): - for batch in convert_to_arrow(func()): - yield batch, arrow_return_type - + yield from map(ArrowBatchTransformer.wrap_struct, convert_to_arrow(func())) else: for row in zip(*args): - for batch in convert_to_arrow(func(*row)): - yield batch, arrow_return_type + yield from map( + ArrowBatchTransformer.wrap_struct, convert_to_arrow(func(*row)) + ) return evaluate @@ -2473,8 +2607,8 @@ def mapper(_, it): else column.to_pylist() for column, conv in zip(a.columns, converters) ] - # The eval function yields an iterator. Each element produced by this - # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type). 
+ # The eval function yields wrapped RecordBatches ready for serialization. + # Each batch has been type-coerced and wrapped into a struct column. yield from eval(*[pylist[o] for o in args_kwargs_offsets], num_rows=a.num_rows) if terminate is not None: yield from terminate() @@ -2508,9 +2642,6 @@ def verify_result(result): "func": f.__name__, }, ) - - # We verify the type of the result and do type corerion - # in the serializer return result # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. @@ -2541,30 +2672,37 @@ def check_return_value(res): return iter([]) def convert_to_arrow(data: Iterable): - data_iter = check_return_value(data) - - # Handle PyArrow Tables/RecordBatches directly + """ + Convert PyArrow data to coerced Arrow batches. + """ is_empty = True - for item in data_iter: + for item in check_return_value(data): is_empty = False if isinstance(item, pa.Table): - yield from item.to_batches() + batches = item.to_batches() or [ + pa.RecordBatch.from_pylist([], schema=item.schema) + ] elif isinstance(item, pa.RecordBatch): - yield item + batches = [item] else: - # Arrow UDTF should only return Arrow types (RecordBatch/Table) raise PySparkRuntimeError( errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR", messageParameters={}, ) + yield from ( + ArrowBatchTransformer.enforce_schema( + verify_result(batch), + return_type, + prefer_large_var_types=runner_conf.use_large_var_types, + ) + for batch in batches + ) if is_empty: yield pa.RecordBatch.from_pylist([], schema=pa.schema(list(arrow_return_type))) def evaluate(*args: pa.RecordBatch): - # For Arrow UDTFs, unpack the RecordBatches and pass them to the function - for batch in convert_to_arrow(func(*args)): - yield verify_result(batch), arrow_return_type + yield from map(ArrowBatchTransformer.wrap_struct, convert_to_arrow(func(*args))) return evaluate @@ -2582,9 +2720,16 @@ def evaluate(*args: pa.RecordBatch): def mapper(_, it): try: - for a in it: - # For PyArrow UDTFs, pass RecordBatches directly (no row conversion needed) - yield from eval(*[a[o] for o in args_kwargs_offsets]) + for batch in it: + # Flatten struct columns at table_arg_offsets into RecordBatches, + # keep other columns as Arrays + args = [ + ArrowBatchTransformer.flatten_struct(batch, column_index=i) + if i in table_arg_offsets + else batch.column(i) + for i in range(batch.num_columns) + ] + yield from eval(*[args[o] for o in args_kwargs_offsets]) if terminate is not None: yield from terminate() except SkipRestOfInputTableException: @@ -2702,6 +2847,13 @@ def mapper(_, it): def read_udfs(pickleSer, infile, eval_type, runner_conf): state_server_port = None key_schema = None + + # Check if legacy arrow batched UDF mode (needs to read input type from stream) + is_legacy_arrow_batched_udf = ( + eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + and runner_conf.use_legacy_pandas_udf_conversion + ) + if eval_type in ( PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, @@ -2731,60 +2883,37 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): state_object_schema = None if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) - elif ( - eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF - or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF + 
elif eval_type in ( + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF, ): state_server_port = read_int(infile) if state_server_port == -1: state_server_port = utf8_deserializer.loads(infile) key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) - # NOTE: if timezone is set here, that implies respectSessionTimeZone is True - if ( - eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF - or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF - ): - ser = ArrowStreamGroupUDFSerializer(runner_conf.assign_cols_by_name) - elif eval_type in ( + # Grouped UDFs: num_dfs=1 (single DataFrame per group) + if eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, - PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, - ): - ser = ArrowStreamAggArrowUDFSerializer(safecheck=True, arrow_cast=True) - elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF, + PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): - ser = ArrowStreamAggPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - runner_conf.int_to_decimal_coercion_enabled, - ) - elif ( - eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF - or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF + ser = ArrowStreamGroupSerializer(num_dfs=1) + # Cogrouped UDFs: num_dfs=2 (two DataFrames per group) + elif eval_type in ( + PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, ): - ser = GroupPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - runner_conf.int_to_decimal_coercion_enabled, - ) - elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: - ser = CogroupArrowUDFSerializer(runner_conf.assign_cols_by_name) - elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - ser = CogroupPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - arrow_cast=True, - ) + ser = ArrowStreamGroupSerializer(num_dfs=2) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: ser = ApplyInPandasWithStateSerializer( runner_conf.timezone, @@ -2813,20 +2942,12 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): runner_conf.arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) - elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF: - ser = TransformWithStateInPySparkRowSerializer(runner_conf.arrow_max_records_per_batch) - elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: - ser = TransformWithStateInPySparkRowInitStateSerializer( - runner_conf.arrow_max_records_per_batch - ) - elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: - ser = ArrowStreamUDFSerializer() elif eval_type in ( - PythonEvalType.SQL_SCALAR_ARROW_UDF, - PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF, 
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF, ): - # Arrow cast and safe check are always enabled - ser = ArrowStreamArrowUDFSerializer(safecheck=True, arrow_cast=True) + ser = ArrowStreamGroupSerializer() + # SQL_ARROW_BATCHED_UDF with new conversion uses ArrowBatchUDFSerializer elif ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and not runner_conf.use_legacy_pandas_udf_conversion @@ -2839,36 +2960,16 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): runner_conf.binary_as_bytes, ) else: - # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of - # pandas Series. See SPARK-27240. - df_for_struct = ( - eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF - or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF - or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF - ) - # Arrow-optimized Python UDF takes a struct type argument as a Row - struct_in_pandas = ( - "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict" - ) - ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF - # Arrow-optimized Python UDF takes input types - input_type = ( - _parse_datatype_json_string(utf8_deserializer.loads(infile)) - if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF - else None - ) - - ser = ArrowStreamPandasUDFSerializer( - runner_conf.timezone, - runner_conf.safecheck, - runner_conf.assign_cols_by_name, - df_for_struct, - struct_in_pandas, - ndarray_as_list, - True, - input_type, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - ) + # All other UDFs use ArrowStreamGroupSerializer with num_dfs=0 + # This includes: SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF, + # SQL_MAP_PANDAS_ITER_UDF, SQL_MAP_ARROW_ITER_UDF, SQL_SCALAR_ARROW_UDF, + # SQL_SCALAR_ARROW_ITER_UDF, SQL_ARROW_BATCHED_UDF (legacy mode) + if is_legacy_arrow_batched_udf: + # Read input type from stream - needed for proper UDT conversion + legacy_input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) + else: + legacy_input_type = None + ser = ArrowStreamGroupSerializer() else: batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) ser = BatchedSerializer(CPickleSerializer(), batch_size) @@ -2880,77 +2981,6 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf): for i in range(num_udfs) ] - is_scalar_iter = eval_type in ( - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, - ) - is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF - is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF - - if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter: - # TODO: Better error message for num_udfs != 1 - if is_scalar_iter: - assert num_udfs == 1, "One SCALAR_ITER UDF expected here." - if is_map_pandas_iter: - assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here." - if is_map_arrow_iter: - assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." - - arg_offsets, udf = udfs[0] - - def func(_, iterator): - num_input_rows = 0 - - def map_batch(batch): - nonlocal num_input_rows - - udf_args = [batch[offset] for offset in arg_offsets] - num_input_rows += len(udf_args[0]) - if len(udf_args) == 1: - return udf_args[0] - else: - return tuple(udf_args) - - iterator = map(map_batch, iterator) - result_iter = udf(iterator) - - num_output_rows = 0 - for result_batch, result_type in result_iter: - num_output_rows += len(result_batch) - # This check is for Scalar Iterator UDF to fail fast. 
- # The length of the entire input can only be explicitly known - # by consuming the input iterator in user side. Therefore, - # it's very unlikely the output length is higher than - # input length. - if is_scalar_iter and num_output_rows > num_input_rows: - raise PySparkRuntimeError( - errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} - ) - yield (result_batch, result_type) - - if is_scalar_iter: - try: - next(iterator) - except StopIteration: - pass - else: - raise PySparkRuntimeError( - errorClass="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF", - messageParameters={}, - ) - - if num_output_rows != num_input_rows: - raise PySparkRuntimeError( - errorClass="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF", - messageParameters={ - "output_length": str(num_output_rows), - "input_length": str(num_input_rows), - }, - ) - - # profiling is not supported for UDF - return func, None, ser, ser - def extract_key_value_indexes(grouped_arg_offsets): """ Helper function to extract the key and value indexes from arg_offsets for the grouped and @@ -2980,6 +3010,124 @@ def extract_key_value_indexes(grouped_arg_offsets): idx += offsets_len return parsed + if eval_type in ( + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, + ): + # Scalar iterator UDFs: require input/output length matching + assert num_udfs == 1, "One SCALAR_ITER UDF expected here." + arg_offsets, udf = udfs[0] + needs_pandas = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + + def input_mapper(batch): + if needs_pandas: + return ArrowBatchTransformer.to_pandas( + batch, timezone=runner_conf.timezone, df_for_struct=True + ) + return batch + + def output_mapper(item): + result, arrow_return_type = item + if needs_pandas: + return PandasBatchTransformer.to_arrow( + [(result, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", + ) + return ArrowBatchTransformer.zip_batches([(result, arrow_return_type)]) + + def func(_, iterator): + num_input_rows = 0 + + def mapper(batch): + nonlocal num_input_rows + converted = input_mapper(batch) + udf_args = [converted[offset] for offset in arg_offsets] + num_input_rows += len(udf_args[0]) + return udf_args[0] if len(udf_args) == 1 else tuple(udf_args) + + iterator = map(mapper, iterator) + result_iter = udf(iterator) + + num_output_rows = 0 + for item in result_iter: + result_batch = output_mapper(item) + num_output_rows += len(result_batch.column(0)) + if num_output_rows > num_input_rows: + raise PySparkRuntimeError( + errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} + ) + yield result_batch + + # Verify consumed all input and output length matches + try: + next(iterator) + except StopIteration: + pass + else: + raise PySparkRuntimeError( + errorClass="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF", + messageParameters={}, + ) + if num_output_rows != num_input_rows: + raise PySparkRuntimeError( + errorClass="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF", + messageParameters={ + "output_length": str(num_output_rows), + "input_length": str(num_input_rows), + }, + ) + + return func, None, ser, ser + + elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: + # Map pandas iterator UDF: no length verification needed + assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here." 
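# Editorial sketch (not part of the patch): the SCALAR_ITER branch above keeps a
# running input row count via a closure and checks the output row count against it,
# both while streaming and once the UDF is exhausted. A minimal, Arrow-free
# illustration of that bookkeeping pattern; the names run_scalar_iter and user_udf
# are hypothetical.
def run_scalar_iter(user_udf, input_batches):
    num_input_rows = 0

    def count_inputs(batch):
        nonlocal num_input_rows
        num_input_rows += len(batch)
        return batch

    outputs = user_udf(map(count_inputs, iter(input_batches)))

    num_output_rows = 0
    for out_batch in outputs:
        num_output_rows += len(out_batch)
        # Fail fast: a scalar iterator UDF cannot produce more rows than it consumed.
        assert num_output_rows <= num_input_rows
        yield out_batch
    # Once the UDF is exhausted, input and output totals must match exactly.
    assert num_output_rows == num_input_rows


assert list(run_scalar_iter(lambda it: ([x + 1 for x in b] for b in it), [[1, 2], [3]])) == [
    [2, 3],
    [4],
]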
+ arg_offsets, udf = udfs[0] + + def input_mapper(batch): + return ArrowBatchTransformer.to_pandas( + batch, timezone=runner_conf.timezone, df_for_struct=True + ) + + def output_mapper(item): + result, arrow_return_type = item + return PandasBatchTransformer.to_arrow( + [(result, arrow_return_type)], + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", + ) + + def func(_, iterator): + def mapper(batch): + converted = input_mapper(batch) + udf_args = [converted[offset] for offset in arg_offsets] + return udf_args[0] if len(udf_args) == 1 else tuple(udf_args) + + iterator = map(mapper, iterator) + result_iter = udf(iterator) + return map(output_mapper, result_iter) + + return func, None, ser, ser + + elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: + # Map arrow iterator UDF (mapInArrow): wrapper handles conversion + assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." + _, udf = udfs[0] + + def func(_, iterator): + # Flatten struct → call wrapper → wrapper returns RecordBatch with wrap_struct + iterator = map(ArrowBatchTransformer.flatten_struct, iterator) + return udf(iterator) + + return func, None, ser, ser + if ( eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF @@ -2995,21 +3143,28 @@ def extract_key_value_indexes(grouped_arg_offsets): arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) - def mapper(series_iter): - # Need to materialize the first series list to get the keys - first_series_list = next(series_iter) + def mapper(batch_iter): + import pandas as pd + + # Convert Arrow batches to pandas DataFrames + df_iter = map( + lambda batch: pd.concat( + ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone), axis=1 + ), + batch_iter, + ) - # Extract key Series from the first batch - key_series = [first_series_list[o] for o in parsed_offsets[0][0]] + # Materialize first batch to extract keys (keys are same for all batches in group) + first_df = next(df_iter) + key_series = [first_df.iloc[:, o] for o in parsed_offsets[0][0]] - # Create generator for value Series lists (one list per batch) - value_series_gen = ( - [series_list[o] for o in parsed_offsets[0][1]] - for series_list in itertools.chain((first_series_list,), series_iter) + # Create generator for value DataFrames: yields DataFrame per batch + value_df_gen = ( + df.iloc[:, parsed_offsets[0][1]] for df in itertools.chain((first_df,), df_iter) ) - # Flatten one level: yield from wrapper to return Iterator[[(df, arrow_type)]] - yield from f(key_series, value_series_gen) + # Wrapper yields wrapped RecordBatches (one or more per group) + yield from f(key_series, value_df_gen) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: # We assume there is only one UDF here because grouped map doesn't @@ -3081,31 +3236,71 @@ def values_gen(): # support combining multiple UDFs. 
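# Editorial sketch (not part of the patch): the grouped-map mapper above must read
# the grouping key from the first batch while still streaming every batch to the
# user function. This is the "peek one element, then chain it back" pattern it
# uses; peek is a hypothetical helper shown only for illustration.
import itertools


def peek(iterator):
    first = next(iterator)                               # materialize only the first element
    return first, itertools.chain((first,), iterator)    # re-attach it for downstream consumers


batches = iter([("k", 1), ("k", 2), ("k", 3)])
first, all_batches = peek(batches)
group_key = first[0]                                     # the key is identical for every batch in a group
assert group_key == "k"
assert [value for _, value in all_batches] == [1, 2, 3]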
assert num_udfs == 1 + import pyarrow as pa + # See TransformWithStateInPySparkExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) - ser.key_offsets = parsed_offsets[0][0] + key_offsets = parsed_offsets[0][0] stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) - def mapper(a): - mode = a[0] + def func(_, batches): + """ + Convert raw Arrow batches to (mode, key, rows) tuples, call mapper, + and convert output back to Arrow batches. + """ - if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: - key = a[1] - values = a[2] + def generate_data_batches(): + """Convert Arrow batches to (key, row) pairs.""" + for batch in batches: + DataRow = Row(*batch.schema.names) + num_cols = batch.num_columns + for row_idx in range(batch.num_rows): + row_key = tuple(batch[o][row_idx].as_py() for o in key_offsets) + row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols))) + yield row_key, row + + def generate_tuples(): + """Group by key and yield (mode, key, rows) tuples.""" + data_batches = generate_data_batches() + for k, g in itertools.groupby(data_batches, key=lambda x: x[0]): + chained = itertools.chain(g) + chained_values = map(lambda x: x[1], chained) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, chained_values) + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) + + def mapper(a): + mode = a[0] + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] + values = a[2] + return f(stateful_processor_api_client, mode, key, values) + else: + return f(stateful_processor_api_client, mode, None, iter([])) + + def convert_output(results): + """Convert UDF output [(iter[Row], pdf_type)] to Arrow batches.""" + for [(iter_row, pdf_type)] in results: + rows_as_dict = [row.asDict(True) for row in iter_row] + record_batch = pa.RecordBatch.from_pylist( + rows_as_dict, schema=pa.schema(list(pdf_type)) + ) + yield ArrowBatchTransformer.wrap_struct(record_batch) - # This must be generator comprehension - do not materialize. - return f(stateful_processor_api_client, mode, key, values) - else: - # mode == PROCESS_TIMER or mode == COMPLETE - return f(stateful_processor_api_client, mode, None, iter([])) + return convert_output(map(mapper, generate_tuples())) + + # Return early since func is already defined + return func, None, ser, ser elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. 
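# Editorial sketch (not part of the patch): the transformWithState mapper above
# leans on itertools.groupby over a key-grouped stream, so each key's rows can be
# handed to the stateful function lazily instead of materializing the partition.
# Minimal illustration, assuming the input already arrives grouped by key.
import itertools

rows = [("a", 1), ("a", 2), ("b", 3)]                    # rows arrive already grouped by key
grouped = []
for key, group in itertools.groupby(rows, key=lambda kv: kv[0]):
    values = (value for _, value in group)               # lazy per-group value iterator
    grouped.append((key, list(values)))
assert grouped == [("a", [1, 2]), ("b", [3])]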
assert num_udfs == 1 + import pyarrow as pa + # See TransformWithStateInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes arg_offsets, f = udfs[0] @@ -3115,22 +3310,116 @@ def mapper(a): # [initStateGroupingOffsets, dedupInitDataOffsets] # ] parsed_offsets = extract_key_value_indexes(arg_offsets) - ser.key_offsets = parsed_offsets[0][0] - ser.init_key_offsets = parsed_offsets[1][0] + key_offsets = parsed_offsets[0][0] + init_key_offsets = parsed_offsets[1][0] + arrow_max_records_per_batch = runner_conf.arrow_max_records_per_batch stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) - def mapper(a): - mode = a[0] + def func(_, batches): + """ + Convert raw Arrow batches to (mode, key, (input_rows, init_rows)) tuples, + call mapper, and convert output back to Arrow batches. + """ - if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: - key = a[1] - values = a[2] + def extract_rows(cur_batch, col_name, offsets): + """Extract rows from a struct column in the batch.""" + data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) - # This must be generator comprehension - do not materialize. - return f(stateful_processor_api_client, mode, key, values) - else: - # mode == PROCESS_TIMER or mode == COMPLETE - return f(stateful_processor_api_client, mode, None, iter([])) + # Check if the entire column is null + if data_column.null_count == len(data_column): + return None + + # Flatten struct column to RecordBatch + flattened = pa.RecordBatch.from_arrays( + data_column.flatten(), schema=pa.schema(data_column.type) + ) + + if flattened.num_rows == 0: + return None + + DataRow = Row(*flattened.schema.names) + + def row_iterator(): + for row_idx in range(flattened.num_rows): + key = tuple(flattened[o][row_idx].as_py() for o in offsets) + row = DataRow( + *( + flattened.column(i)[row_idx].as_py() + for i in range(flattened.num_columns) + ) + ) + yield (key, row) + + return row_iterator() + + def generate_data_batches(): + """ + Convert Arrow batches to (key, input_row, init_row) triples. + Each batch contains either inputData or initState, not both. 
+ """ + for batch in batches: + input_result = extract_rows(batch, "inputData", key_offsets) + init_result = extract_rows(batch, "initState", init_key_offsets) + + assert not (input_result is not None and init_result is not None) + + if input_result is not None: + for key, input_data_row in input_result: + yield (key, input_data_row, None) + elif init_result is not None: + for key, init_state_row in init_result: + yield (key, None, init_state_row) + + def generate_tuples(): + """Group by key and yield (mode, key, (input_rows, init_rows)) tuples.""" + data_batches = generate_data_batches() + + for k, g in itertools.groupby(data_batches, key=lambda x: x[0]): + input_rows = [] + init_rows = [] + + for batch_key, input_row, init_row in g: + if input_row is not None: + input_rows.append(input_row) + if init_row is not None: + init_rows.append(init_row) + + total_len = len(input_rows) + len(init_rows) + if total_len >= arrow_max_records_per_batch: + ret_tuple = (iter(input_rows), iter(init_rows)) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) + input_rows = [] + init_rows = [] + + if input_rows or init_rows: + ret_tuple = (iter(input_rows), iter(init_rows)) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) + + yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) + yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None) + + def mapper(a): + mode = a[0] + if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA: + key = a[1] + values = a[2] + return f(stateful_processor_api_client, mode, key, values) + else: + return f(stateful_processor_api_client, mode, None, iter([])) + + def convert_output(results): + """Convert UDF output [(iter[Row], pdf_type)] to Arrow batches.""" + for [(iter_row, pdf_type)] in results: + rows_as_dict = [row.asDict(True) for row in iter_row] + record_batch = pa.RecordBatch.from_pylist( + rows_as_dict, schema=pa.schema(list(pdf_type)) + ) + yield ArrowBatchTransformer.wrap_struct(record_batch) + + return convert_output(map(mapper, generate_tuples())) + + # Return early since func is already defined + return func, None, ser, ser elif ( eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF @@ -3147,12 +3436,6 @@ def mapper(a): arg_offsets, f = udfs[0] parsed_offsets = extract_key_value_indexes(arg_offsets) - def batch_from_offset(batch, offsets): - return pa.RecordBatch.from_arrays( - arrays=[batch.columns[o] for o in offsets], - names=[batch.schema.names[o] for o in offsets], - ) - def mapper(batches): # Flatten struct column into separate columns flattened = map(ArrowBatchTransformer.flatten_struct, batches) @@ -3160,9 +3443,9 @@ def mapper(batches): # Need to materialize the first batch to get the keys first_batch = next(flattened) - keys = batch_from_offset(first_batch, parsed_offsets[0][0]) + keys = first_batch.select(parsed_offsets[0][0]) value_batches = ( - batch_from_offset(batch, parsed_offsets[0][1]) + batch.select(parsed_offsets[0][1]) for batch in itertools.chain((first_batch,), flattened) ) @@ -3206,6 +3489,8 @@ def mapper(a): return f(keys, vals, state) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + import pyarrow as pa + # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. 
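# Editorial sketch (not part of the patch): the refactor above drops the local
# batch_from_offset helper in favor of RecordBatch.select with column indices.
# Assumes a PyArrow version that provides RecordBatch.select.
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"k": [1, 1], "v": [10, 20], "w": [7, 8]})
keys = batch.select([0])             # equivalent to rebuilding a batch from batch.columns[0]
vals = batch.select([1, 2])
assert keys.schema.names == ["k"]
assert vals.schema.names == ["v", "w"]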
assert num_udfs == 1 @@ -3213,11 +3498,26 @@ def mapper(a): parsed_offsets = extract_key_value_indexes(arg_offsets) + # Cogrouped UDF receives tuple of (list of RecordBatches, list of RecordBatches) def mapper(a): - df1_keys = [a[0][o] for o in parsed_offsets[0][0]] - df1_vals = [a[0][o] for o in parsed_offsets[0][1]] - df2_keys = [a[1][o] for o in parsed_offsets[1][0]] - df2_vals = [a[1][o] for o in parsed_offsets[1][1]] + # a is tuple[list[pa.RecordBatch], list[pa.RecordBatch]] + left_batches, right_batches = a[0], a[1] + + # Convert batches to tables (batches already have flattened struct columns) + left_table = pa.Table.from_batches(left_batches) if left_batches else pa.table({}) + right_table = pa.Table.from_batches(right_batches) if right_batches else pa.table({}) + + # Convert tables to pandas Series lists + left_series = ArrowBatchTransformer.to_pandas(left_table, timezone=runner_conf.timezone) + right_series = ArrowBatchTransformer.to_pandas( + right_table, timezone=runner_conf.timezone + ) + + df1_keys = [left_series[o] for o in parsed_offsets[0][0]] + df1_vals = [left_series[o] for o in parsed_offsets[0][1]] + df2_keys = [right_series[o] for o in parsed_offsets[1][0]] + df2_vals = [right_series[o] for o in parsed_offsets[1][1]] + return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: @@ -3230,20 +3530,11 @@ def mapper(a): parsed_offsets = extract_key_value_indexes(arg_offsets) - def batch_from_offset(batch, offsets): - return pa.RecordBatch.from_arrays( - arrays=[batch.columns[o] for o in offsets], - names=[batch.schema.names[o] for o in offsets], - ) - - def table_from_batches(batches, offsets): - return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches]) - def mapper(a): - df1_keys = table_from_batches(a[0], parsed_offsets[0][0]) - df1_vals = table_from_batches(a[0], parsed_offsets[0][1]) - df2_keys = table_from_batches(a[1], parsed_offsets[1][0]) - df2_vals = table_from_batches(a[1], parsed_offsets[1][1]) + df1_keys = pa.Table.from_batches([batch.select(parsed_offsets[0][0]) for batch in a[0]]) + df1_vals = pa.Table.from_batches([batch.select(parsed_offsets[0][1]) for batch in a[0]]) + df2_keys = pa.Table.from_batches([batch.select(parsed_offsets[1][0]) for batch in a[1]]) + df2_vals = pa.Table.from_batches([batch.select(parsed_offsets[1][1]) for batch in a[1]]) return f(df1_keys, df1_vals, df2_keys, df2_vals) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF: @@ -3253,13 +3544,15 @@ def mapper(a): arg_offsets, f = udfs[0] - # Convert to iterator of batches: Iterator[pa.Array] for single column, + # Convert to iterator of arrays: Iterator[pa.Array] for single column, # or Iterator[Tuple[pa.Array, ...]] for multiple columns + # a is Iterator[pa.RecordBatch], extract columns from each batch def mapper(a): if len(arg_offsets) == 1: - batch_iter = (batch_columns[arg_offsets[0]] for batch_columns in a) + batch_iter = (batch.column(arg_offsets[0]) for batch in a) else: - batch_iter = (tuple(batch_columns[o] for o in arg_offsets) for batch_columns in a) + batch_iter = (tuple(batch.column(o) for o in arg_offsets) for batch in a) + # f returns pa.RecordBatch directly return f(batch_iter) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF: @@ -3269,21 +3562,23 @@ def mapper(a): arg_offsets, f = udfs[0] - # Convert to iterator of pandas Series: + # POC: Convert raw Arrow batches to pandas Series in worker layer # - Iterator[pd.Series] for single column # - 
Iterator[Tuple[pd.Series, ...]] for multiple columns def mapper(batch_iter): - # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch - # Convert to Iterator[pd.Series] or Iterator[Tuple[pd.Series, ...]] based on arg_offsets - if len(arg_offsets) == 1: - # Single column: Iterator[Tuple[pd.Series, ...]] -> Iterator[pd.Series] - series_iter = (batch_series[arg_offsets[0]] for batch_series in batch_iter) - else: - # Multiple columns: Iterator[Tuple[pd.Series, ...]] -> - # Iterator[Tuple[pd.Series, ...]] - series_iter = ( - tuple(batch_series[o] for o in arg_offsets) for batch_series in batch_iter - ) + # batch_iter is Iterator[pa.RecordBatch] (raw batches from serializer) + # Convert to pandas Series lists, then select columns based on arg_offsets + series_lists = map( + lambda batch: ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone), + batch_iter, + ) + series_iter = ( + series_list[arg_offsets[0]] + if len(arg_offsets) == 1 + else tuple(series_list[o] for o in arg_offsets) + for series_list in series_lists + ) + # Wrapper returns RecordBatch directly (conversion done in wrapper) return f(series_iter) elif eval_type in ( @@ -3293,83 +3588,158 @@ def mapper(batch_iter): import pyarrow as pa # For SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF, - # convert iterator of batch columns to a concatenated RecordBatch + # convert iterator of RecordBatch to a concatenated RecordBatch def mapper(a): - # a is Iterator[Tuple[pa.Array, ...]] - convert to RecordBatch - batches = [] - for batch_columns in a: - # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch - batch = pa.RecordBatch.from_arrays( - batch_columns, names=["_%d" % i for i in range(len(batch_columns))] - ) - batches.append(batch) - - # Concatenate all batches into one - if hasattr(pa, "concat_batches"): - concatenated_batch = pa.concat_batches(batches) - else: - # pyarrow.concat_batches not supported in old versions - concatenated_batch = pa.RecordBatch.from_struct_array( - pa.concat_arrays([b.to_struct_array() for b in batches]) - ) + # a is Iterator[pa.RecordBatch] - collect and concatenate all batches + concatenated_batch = ArrowBatchTransformer.concat_batches(list(a)) - # Extract series using offsets (concatenated_batch.columns[o] gives pa.Array) - result = tuple( + # Each UDF returns pa.RecordBatch + result_batches = [ f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs - ) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. 
- if len(result) == 1: - return result[0] - else: - return result + ] + + # Merge RecordBatches from all UDFs horizontally + return ArrowBatchTransformer.zip_batches(result_batches) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): - import pandas as pd - # For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF, - # convert iterator of batch tuples to concatenated pandas Series + # batch_iter is now Iterator[pa.RecordBatch] (raw batches from serializer) + # Convert to pandas and concatenate into single Series per column def mapper(batch_iter): - # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch - # Collect all batches and concatenate into single Series per column - batches = list(batch_iter) - if not batches: - # Empty batches - determine num_columns from all UDFs' arg_offsets - all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets] - num_columns = max(all_offsets) + 1 if all_offsets else 0 - concatenated = [pd.Series(dtype=object) for _ in range(num_columns)] - else: - # Use actual number of columns from the first batch - num_columns = len(batches[0]) + import pandas as pd + + # Convert raw Arrow batches to pandas Series in worker layer + series_batches = [ + ArrowBatchTransformer.to_pandas(batch, timezone=runner_conf.timezone) + for batch in batch_iter + ] + # Concatenate all batches column-wise into single Series per column + if series_batches: + num_columns = len(series_batches[0]) concatenated = [ - pd.concat([batch[i] for batch in batches], ignore_index=True) + pd.concat([batch[i] for batch in series_batches], ignore_index=True) for i in range(num_columns) ] + else: + all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets] + num_columns = max(all_offsets) + 1 if all_offsets else 0 + concatenated = [pd.Series(dtype=object) for _ in range(num_columns)] + + # Each UDF returns pa.RecordBatch (conversion done in wrapper) + result_batches = [f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs] - result = tuple(f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - if len(result) == 1: - return result[0] + # Merge all RecordBatches horizontally + return ArrowBatchTransformer.zip_batches(result_batches) + + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: + # Regular batched UDF uses pickle serialization, no pyarrow needed + # Input: individual rows (list of column values) + # Output: individual results + def mapper(row): + # row is a list of column values [col0, col1, col2, ...] + # Each UDF wrapper takes its columns and returns the result + results = [f(*[row[o] for o in arg_offsets]) for arg_offsets, f in udfs] + if len(results) == 1: + return results[0] else: - return result + return tuple(results) else: + import pyarrow as pa - def mapper(a): - result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. 
- if len(result) == 1: - return result[0] + # Determine UDF type for appropriate handling + is_arrow_batched_udf = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + is_scalar_pandas_udf = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + is_scalar_arrow_udf = eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF + + def mapper(batch): + """Process a single batch through all UDFs. + + For scalar pandas/arrow UDFs: + 1. Input conversion: Arrow → Pandas (for pandas UDF only) + 2. Call each wrapper, collect (result, arrow_type) pairs + 3. Output conversion: Pandas/Arrow → Arrow RecordBatch + """ + # Step 1: Input conversion (Arrow → Pandas for pandas UDF or legacy arrow batched UDF) + if is_scalar_pandas_udf: + data = ArrowBatchTransformer.to_pandas( + batch, timezone=runner_conf.timezone, df_for_struct=True + ) + elif is_legacy_arrow_batched_udf: + # Legacy mode: convert Arrow to pandas Series for wrap_arrow_batch_udf_legacy + # which expects *args: pd.Series input + # Use legacy_input_type schema for proper UDT conversion + data = ArrowBatchTransformer.to_pandas( + batch, + timezone=runner_conf.timezone, + schema=legacy_input_type, + struct_in_pandas="row", + ndarray_as_list=True, + ) else: - return result + data = batch + + # Step 2: Call each wrapper, collect results + # Wrappers return (result, arrow_return_type) tuples + results_with_types = [f(*[data[o] for o in arg_offsets]) for arg_offsets, f in udfs] + + # Step 3: Output conversion + # For arrow batched UDF (non-legacy), return raw result (serializer handles conversion) + if is_arrow_batched_udf and not is_legacy_arrow_batched_udf: + if len(results_with_types) == 1: + return results_with_types[0] + else: + return results_with_types + + # For legacy arrow batched UDF, convert pandas results to Arrow + # Results are (pandas.Series, arrow_type, spark_type) tuples + if is_legacy_arrow_batched_udf: + return PandasBatchTransformer.to_arrow( + results_with_types, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="row", + ) + + # For scalar pandas UDF, convert all results to Arrow at once + if is_scalar_pandas_udf: + return PandasBatchTransformer.to_arrow( + results_with_types, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + arrow_cast=True, + struct_in_pandas="dict", + ) + + # For scalar arrow UDF, zip Arrow arrays into a RecordBatch + if is_scalar_arrow_udf: + return ArrowBatchTransformer.zip_batches(results_with_types) + + # For grouped/cogrouped map UDFs: + # All wrappers yield batches, so mapper returns generator → need flatten + if eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, + ): + + def func(_, it): + # All grouped/cogrouped wrappers yield, so flatten with chain.from_iterable + return itertools.chain.from_iterable(map(mapper, it)) + + else: - def func(_, it): - return map(mapper, it) + def func(_, it): + return map(mapper, it) # profiling is not supported for UDF return func, None, ser, ser
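The final dispatch above separates wrappers that return exactly one batch per input from grouped/cogrouped wrappers that yield several batches per group; the latter are flattened with itertools.chain.from_iterable. A minimal sketch of that flattening (not part of the patch), with plain lists standing in for RecordBatches:

import itertools


def grouped_wrapper(group):
    # a grouped/cogrouped wrapper may yield more than one batch per group
    for item in group:
        yield item * 10


groups = [[1, 2], [3]]
flat = itertools.chain.from_iterable(map(grouped_wrapper, groups))
assert list(flat) == [10, 20, 30]

# Non-grouped wrappers return exactly one result per input batch, so plain map suffices:
assert list(map(sum, groups)) == [3, 3]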